{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# default_exp models.layers"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Layers\n",
    "\n",
    "> Helper function used to build PyTorch timeseries models."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "from tsai.imports import *\n",
    "from tsai.utils import *\n",
    "from torch.nn.init import normal_\n",
    "from fastai.torch_core import Module\n",
    "from fastai.layers import *\n",
    "from torch.nn.utils import weight_norm, spectral_norm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def noop(x): return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def init_lin_zero(m):\n",
    "    if isinstance(m, (nn.Linear)): \n",
    "        if getattr(m, 'bias', None) is not None: nn.init.constant_(m.bias, 0)\n",
    "        nn.init.constant_(m.weight, 0)\n",
    "    for l in m.children(): init_lin_zero(l)\n",
    "        \n",
    "lin_zero_init = init_lin_zero"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export    \n",
    "class SwishBeta(Module):\n",
    "    def __init__(self, beta=1.): \n",
    "        self.sigmoid = torch.sigmoid\n",
    "        self.beta = nn.Parameter(torch.Tensor(1).fill_(beta))\n",
    "    def forward(self, x): return x.mul(self.sigmoid(x*self.beta))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "class Chomp1d(nn.Module):\n",
    "    def __init__(self, chomp_size):\n",
    "        super(Chomp1d, self).__init__()\n",
    "        self.chomp_size = chomp_size\n",
    "        \n",
    "    def forward(self, x):\n",
    "        return x[:, :, :-self.chomp_size].contiguous()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def same_padding1d(seq_len, ks, stride=1, dilation=1):\n",
    "    \"Same padding formula as used in Tensorflow\"\n",
    "    p = (seq_len - 1) * stride + (ks - 1) * dilation + 1 - seq_len\n",
    "    return p // 2, p - p // 2\n",
    "\n",
    "\n",
    "class Pad1d(nn.ConstantPad1d):\n",
    "    def __init__(self, padding, value=0.):\n",
    "        super().__init__(padding, value)\n",
    "\n",
    "        \n",
    "@delegates(nn.Conv1d)\n",
    "class SameConv1d(Module):\n",
    "    \"Conv1d with padding='same'\"\n",
    "    def __init__(self, ni, nf, ks=3, stride=1, dilation=1, **kwargs):\n",
    "        self.ks, self.stride, self.dilation = ks, stride, dilation\n",
    "        self.conv1d_same = nn.Conv1d(ni, nf, ks, stride=stride, dilation=dilation, **kwargs)\n",
    "        self.weight = self.conv1d_same.weight\n",
    "        self.bias = self.conv1d_same.bias\n",
    "        self.pad = Pad1d\n",
    "\n",
    "    def forward(self, x):\n",
    "        self.padding = same_padding1d(x.shape[-1], self.ks, dilation=self.dilation) #stride=self.stride not used in padding calculation!\n",
    "        return self.conv1d_same(self.pad(self.padding)(x))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def same_padding2d(H, W, ks, stride=(1, 1), dilation=(1, 1)):\n",
    "    \"Same padding formula as used in Tensorflow\"\n",
    "    if isinstance(ks, Integral): ks = (ks, ks)\n",
    "    if ks[0] == 1:  p_h = 0\n",
    "    else:  p_h = (H - 1) * stride[0] + (ks[0] - 1) * dilation[0] + 1 - H\n",
    "    if ks[1] == 1:  p_w = 0\n",
    "    else:  p_w = (W - 1) * stride[1] + (ks[1] - 1) * dilation[1] + 1 - W\n",
    "    return (p_w // 2, p_w - p_w // 2, p_h // 2, p_h - p_h // 2)\n",
    "\n",
    "\n",
    "class Pad2d(nn.ConstantPad2d):\n",
    "    def __init__(self, padding, value=0.):\n",
    "        super().__init__(padding, value)\n",
    "\n",
    "\n",
    "@delegates(nn.Conv2d)\n",
    "class Conv2dSame(Module):\n",
    "    \"Conv2d with padding='same'\"\n",
    "    def __init__(self, ni, nf, ks=(3, 3), stride=(1, 1), dilation=(1, 1), **kwargs):\n",
    "        if isinstance(ks, Integral): ks = (ks, ks)\n",
    "        if isinstance(stride, Integral): stride = (stride, stride)\n",
    "        if isinstance(dilation, Integral): dilation = (dilation, dilation)\n",
    "        self.ks, self.stride, self.dilation = ks, stride, dilation\n",
    "        self.conv2d_same = nn.Conv2d(ni, nf, ks, stride=stride, dilation=dilation, **kwargs)\n",
    "        self.weight = self.conv2d_same.weight\n",
    "        self.bias = self.conv2d_same.bias\n",
    "        self.pad = Pad2d\n",
    "\n",
    "    def forward(self, x):\n",
    "        self.padding = same_padding2d(x.shape[-2], x.shape[-1], self.ks, dilation=self.dilation) #stride=self.stride not used in padding calculation!\n",
    "        return self.conv2d_same(self.pad(self.padding)(x))\n",
    "    \n",
    "    \n",
    "@delegates(nn.Conv2d)\n",
    "def Conv2d(ni, nf, kernel_size=None, ks=None, stride=1, padding='same', dilation=1, init='auto', bias_std=0.01, **kwargs):\n",
    "    \"conv1d layer with padding='same', 'valid', or any integer (defaults to 'same')\"\n",
    "    assert not (kernel_size and ks), 'use kernel_size or ks but not both simultaneously'\n",
    "    assert kernel_size is not None or ks is not None, 'you need to pass a ks'\n",
    "    kernel_size = kernel_size or ks\n",
    "    if padding == 'same': \n",
    "        conv = Conv2dSame(ni, nf, kernel_size, stride=stride, dilation=dilation, **kwargs)\n",
    "    elif padding == 'valid': conv = nn.Conv2d(ni, nf, kernel_size, stride=stride, padding=0, dilation=dilation, **kwargs)\n",
    "    else: conv = nn.Conv2d(ni, nf, kernel_size, stride=stride, padding=padding, dilation=dilation, **kwargs)\n",
    "    init_linear(conv, None, init=init, bias_std=bias_std)\n",
    "    return conv"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bs = 2\n",
    "c_in = 3\n",
    "c_out = 5\n",
    "h = 16\n",
    "w = 20\n",
    "t = torch.rand(bs, c_in, h, w)\n",
    "test_eq(Conv2dSame(c_in, c_out, ks=3, stride=1, dilation=1, bias=False)(t).shape, (bs, c_out, h, w))\n",
    "test_eq(Conv2dSame(c_in, c_out, ks=(3, 1), stride=1, dilation=1, bias=False)(t).shape, (bs, c_out, h, w))\n",
    "test_eq(Conv2dSame(c_in, c_out, ks=3, stride=(1, 1), dilation=(2, 2), bias=False)(t).shape, (bs, c_out, h, w))\n",
    "test_eq(Conv2dSame(c_in, c_out, ks=3, stride=(2, 2), dilation=(1, 1), bias=False)(t).shape, (bs, c_out, h//2, w//2))\n",
    "test_eq(Conv2dSame(c_in, c_out, ks=3, stride=(2, 2), dilation=(2, 2), bias=False)(t).shape, (bs, c_out, h//2, w//2))\n",
    "test_eq(Conv2d(c_in, c_out, ks=3, padding='same', stride=1, dilation=1, bias=False)(t).shape, (bs, c_out, h, w))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "class CausalConv1d(torch.nn.Conv1d):\n",
    "    def __init__(self, ni, nf, ks, stride=1, dilation=1, groups=1, bias=True):\n",
    "        super(CausalConv1d, self).__init__(ni, nf, kernel_size=ks, stride=stride, padding=0, dilation=dilation, groups=groups, bias=bias)\n",
    "        self.__padding = (ks - 1) * dilation\n",
    "    def forward(self, input):\n",
    "        return super(CausalConv1d, self).forward(F.pad(input, (self.__padding, 0)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "@delegates(nn.Conv1d)\n",
    "def Conv1d(ni, nf, kernel_size=None, ks=None, stride=1, padding='same', dilation=1, init='auto', bias_std=0.01, **kwargs):\n",
    "    \"conv1d layer with padding='same', 'causal', 'valid', or any integer (defaults to 'same')\"\n",
    "    assert not (kernel_size and ks), 'use kernel_size or ks but not both simultaneously'\n",
    "    assert kernel_size is not None or ks is not None, 'you need to pass a ks'\n",
    "    kernel_size = kernel_size or ks\n",
    "    if padding == 'same': \n",
    "        if kernel_size%2==1: \n",
    "            conv = nn.Conv1d(ni, nf, kernel_size, stride=stride, padding=kernel_size//2 * dilation, dilation=dilation, **kwargs)\n",
    "        else:\n",
    "            conv = SameConv1d(ni, nf, kernel_size, stride=stride, dilation=dilation, **kwargs)\n",
    "    elif padding == 'causal': conv = CausalConv1d(ni, nf, kernel_size, stride=stride, dilation=dilation, **kwargs)\n",
    "    elif padding == 'valid': conv = nn.Conv1d(ni, nf, kernel_size, stride=stride, padding=0, dilation=dilation, **kwargs)\n",
    "    else: conv = nn.Conv1d(ni, nf, kernel_size, stride=stride, padding=padding, dilation=dilation, **kwargs)\n",
    "    init_linear(conv, None, init=init, bias_std=bias_std)\n",
    "    return conv"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bs = 2\n",
    "c_in = 3\n",
    "c_out = 5\n",
    "seq_len = 512\n",
    "t = torch.rand(bs, c_in, seq_len)\n",
    "dilation = 1\n",
    "test_eq(CausalConv1d(c_in, c_out, ks=3, dilation=dilation)(t).shape, Conv1d(c_in, c_out, ks=3, padding=\"same\", dilation=dilation)(t).shape)\n",
    "dilation = 2\n",
    "test_eq(CausalConv1d(c_in, c_out, ks=3, dilation=dilation)(t).shape, Conv1d(c_in, c_out, ks=3, padding=\"same\", dilation=dilation)(t).shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bs = 2\n",
    "ni = 3\n",
    "nf = 5\n",
    "seq_len = 6\n",
    "ks = 3\n",
    "t = torch.rand(bs, c_in, seq_len)\n",
    "test_eq(Conv1d(ni, nf, ks, padding=0)(t).shape, (bs, c_out, seq_len - (2 * (ks//2))))\n",
    "test_eq(Conv1d(ni, nf, ks, padding='valid')(t).shape, (bs, c_out, seq_len - (2 * (ks//2))))\n",
    "test_eq(Conv1d(ni, nf, ks, padding='same')(t).shape, (bs, c_out, seq_len))\n",
    "test_eq(Conv1d(ni, nf, ks, padding='causal')(t).shape, (bs, c_out, seq_len))\n",
    "test_error('use kernel_size or ks but not both simultaneously', Conv1d, ni, nf, kernel_size=3, ks=3)\n",
    "test_error('you need to pass a ks', Conv1d, ni, nf)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Conv1d(3, 5, kernel_size=(3,), stride=(1,), padding=(1,))"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "conv = Conv1d(ni, nf, ks, padding='same')\n",
    "init_linear(conv, None, init='auto', bias_std=.01)\n",
    "conv"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "CausalConv1d(3, 5, kernel_size=(3,), stride=(1,))"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "conv = Conv1d(ni, nf, ks, padding='causal')\n",
    "init_linear(conv, None, init='auto', bias_std=.01)\n",
    "conv"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Conv1d(3, 5, kernel_size=(3,), stride=(1,))"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "conv = Conv1d(ni, nf, ks, padding='valid')\n",
    "init_linear(conv, None, init='auto', bias_std=.01)\n",
    "weight_norm(conv)\n",
    "conv"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Conv1d(3, 5, kernel_size=(3,), stride=(1,))"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "conv = Conv1d(ni, nf, ks, padding=0)\n",
    "init_linear(conv, None, init='auto', bias_std=.01)\n",
    "weight_norm(conv)\n",
    "conv"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "class SeparableConv1d(Module):\n",
    "    def __init__(self, ni, nf, ks, stride=1, padding='same', dilation=1, bias=True, bias_std=0.01):\n",
    "        self.depthwise_conv = Conv1d(ni, ni, ks, stride=stride, padding=padding, dilation=dilation, groups=ni, bias=bias)\n",
    "        self.pointwise_conv = nn.Conv1d(ni, nf, 1, stride=1, padding=0, dilation=1, groups=1, bias=bias)\n",
    "        if bias:\n",
    "            if bias_std != 0: \n",
    "                normal_(self.depthwise_conv.bias, 0, bias_std)\n",
    "                normal_(self.pointwise_conv.bias, 0, bias_std)\n",
    "            else: \n",
    "                self.depthwise_conv.bias.data.zero_()\n",
    "                self.pointwise_conv.bias.data.zero_()\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.depthwise_conv(x)\n",
    "        x = self.pointwise_conv(x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bs = 64\n",
    "c_in = 6\n",
    "c_out = 5\n",
    "seq_len = 512\n",
    "t = torch.rand(bs, c_in, seq_len)\n",
    "test_eq(SeparableConv1d(c_in, c_out, 3)(t).shape, (bs, c_out, seq_len))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "class AddCoords1d(Module):\n",
    "    \"\"\"Add coordinates to ease position identification without modifying mean and std\"\"\"\n",
    "    def forward(self, x):\n",
    "        bs, _, seq_len = x.shape\n",
    "        cc = torch.linspace(-1,1,x.shape[-1], device=x.device).repeat(bs, 1, 1)\n",
    "        cc = (cc - cc.mean()) / cc.std()\n",
    "        x = torch.cat([x, cc], dim=1)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bs = 2\n",
    "c_in = 3\n",
    "c_out = 5\n",
    "seq_len = 50\n",
    "\n",
    "t = torch.rand(bs, c_in, seq_len)\n",
    "t = (t - t.mean()) / t.std()\n",
    "test_eq(AddCoords1d()(t).shape, (bs, c_in + 1, seq_len))\n",
    "new_t = AddCoords1d()(t)\n",
    "test_close(new_t.mean(),0, 1e-2)\n",
    "test_close(new_t.std(), 1, 1e-2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "class ConvBlock(nn.Sequential):\n",
    "    \"Create a sequence of conv1d (`ni` to `nf`), activation (if `act_cls`) and `norm_type` layers.\"\n",
    "    def __init__(self, ni, nf, kernel_size=None, ks=3, stride=1, padding='same', bias=None, bias_std=0.01, norm='Batch', zero_norm=False, bn_1st=True,\n",
    "                 act=nn.ReLU, act_kwargs={}, init='auto', dropout=0., xtra=None, coord=False, separable=False,  **kwargs):\n",
    "        kernel_size = kernel_size or ks\n",
    "        ndim = 1\n",
    "        layers = [AddCoords1d()] if coord else []\n",
    "        norm_type = getattr(NormType,f\"{snake2camel(norm)}{'Zero' if zero_norm else ''}\") if norm is not None else None\n",
    "        bn = norm_type in (NormType.Batch, NormType.BatchZero)\n",
    "        inn = norm_type in (NormType.Instance, NormType.InstanceZero)\n",
    "        if bias is None: bias = not (bn or inn)\n",
    "        if separable: conv = SeparableConv1d(ni + coord, nf, ks=kernel_size, bias=bias, stride=stride, padding=padding, **kwargs)\n",
    "        else: conv = Conv1d(ni + coord, nf, ks=kernel_size, bias=bias, stride=stride, padding=padding, **kwargs)\n",
    "        act = None if act is None else act(**act_kwargs)\n",
    "        if not separable: init_linear(conv, act, init=init, bias_std=bias_std)\n",
    "        if   norm_type==NormType.Weight:   conv = weight_norm(conv)\n",
    "        elif norm_type==NormType.Spectral: conv = spectral_norm(conv)\n",
    "        layers += [conv]\n",
    "        act_bn = []        \n",
    "        if act is not None: act_bn.append(act)\n",
    "        if bn: act_bn.append(BatchNorm(nf, norm_type=norm_type, ndim=ndim))\n",
    "        if inn: act_bn.append(InstanceNorm(nf, norm_type=norm_type, ndim=ndim))\n",
    "        if bn_1st: act_bn.reverse()\n",
    "        if dropout: layers += [nn.Dropout(dropout)]\n",
    "        layers += act_bn\n",
    "        if xtra: layers.append(xtra)\n",
    "        super().__init__(*layers)     \n",
    "                            \n",
    "Conv = named_partial('Conv', ConvBlock, norm=None, act=None)\n",
    "ConvBN = named_partial('ConvBN', ConvBlock, norm='Batch', act=None)\n",
    "CoordConv = named_partial('CoordConv', ConvBlock, norm=None, act=None, coord=True)\n",
    "SepConv = named_partial('SepConv', ConvBlock, norm=None, act=None, separable=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "class ResBlock1dPlus(Module):\n",
    "    \"Resnet block from `ni` to `nh` with `stride`\"\n",
    "    @delegates(ConvLayer.__init__)\n",
    "    def __init__(self, expansion, ni, nf, coord=False, stride=1, groups=1, reduction=None, nh1=None, nh2=None, dw=False, g2=1,\n",
    "                 sa=False, sym=False, norm='Batch', zero_norm=True, act_cls=defaults.activation, ks=3,\n",
    "                 pool=AvgPool, pool_first=True, **kwargs):\n",
    "        if nh2 is None: nh2 = nf\n",
    "        if nh1 is None: nh1 = nh2\n",
    "        nf,ni = nf*expansion,ni*expansion\n",
    "        k0 = dict(norm=norm, zero_norm=False, act=act_cls, **kwargs)\n",
    "        k1 = dict(norm=norm, zero_norm=zero_norm, act=None, **kwargs)\n",
    "        convpath  = [ConvBlock(ni,  nh2, ks, coord=coord, stride=stride, groups=ni if dw else groups, **k0),\n",
    "                     ConvBlock(nh2,  nf, ks, coord=coord, groups=g2, **k1)\n",
    "        ] if expansion == 1 else [\n",
    "                     ConvBlock(ni,  nh1, 1, coord=coord, **k0),\n",
    "                     ConvBlock(nh1, nh2, ks, coord=coord, stride=stride, groups=nh1 if dw else groups, **k0),\n",
    "                     ConvBlock(nh2,  nf, 1, coord=coord, groups=g2, **k1)]\n",
    "        if reduction: convpath.append(SEModule(nf, reduction=reduction, act_cls=act_cls))\n",
    "        if sa: convpath.append(SimpleSelfAttention(nf,ks=1,sym=sym))\n",
    "        self.convpath = nn.Sequential(*convpath)\n",
    "        idpath = []\n",
    "        if ni!=nf: idpath.append(ConvBlock(ni, nf, 1, coord=coord, act=None, **kwargs))\n",
    "        if stride!=1: idpath.insert((1,0)[pool_first], pool(stride, ndim=1, ceil_mode=True))\n",
    "        self.idpath = nn.Sequential(*idpath)\n",
    "        self.act = defaults.activation(inplace=True) if act_cls is defaults.activation else act_cls()\n",
    "\n",
    "    def forward(self, x): return self.act(self.convpath(x) + self.idpath(x))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def SEModule1d(ni, reduction=16, act=nn.ReLU, act_kwargs={}):\n",
    "    \"Squeeze and excitation module for 1d\"\n",
    "    nf = math.ceil(ni//reduction/8)*8\n",
    "    assert nf != 0, 'nf cannot be 0'\n",
    "    return SequentialEx(nn.AdaptiveAvgPool1d(1), \n",
    "                        ConvBlock(ni, nf, ks=1, norm=None, act=act, act_kwargs=act_kwargs),\n",
    "                        ConvBlock(nf, ni, ks=1, norm=None, act=nn.Sigmoid), ProdLayer())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "t = torch.rand(8, 32, 12)\n",
    "test_eq(SEModule1d(t.shape[1], 16, act=nn.ReLU, act_kwargs={})(t).shape, t.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def Norm(nf, ndim=1, norm='Batch', zero_norm=False, init=True, **kwargs):\n",
    "    \"Norm layer with `nf` features and `ndim` with auto init.\"\n",
    "    assert 1 <= ndim <= 3\n",
    "    nl = getattr(nn, f\"{snake2camel(norm)}Norm{ndim}d\")(nf, **kwargs)\n",
    "    if nl.affine and init:\n",
    "        nl.bias.data.fill_(1e-3)\n",
    "        nl.weight.data.fill_(0. if zero_norm else 1.)\n",
    "    return nl\n",
    "\n",
    "BN1d = partial(Norm, ndim=1, norm='Batch')\n",
    "IN1d = partial(Norm, ndim=1, norm='Instance')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bs = 2\n",
    "ni = 3\n",
    "nf = 5\n",
    "sl = 4\n",
    "ks = 5\n",
    "\n",
    "t = torch.rand(bs, ni, sl)\n",
    "test_eq(ConvBlock(ni, nf, ks)(t).shape, (bs, nf, sl))\n",
    "test_eq(ConvBlock(ni, nf, ks, padding='causal')(t).shape, (bs, nf, sl))\n",
    "test_eq(ConvBlock(ni, nf, ks, coord=True)(t).shape, (bs, nf, sl))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_eq(BN1d(ni)(t).shape, (bs, ni, sl))\n",
    "test_eq(BN1d(ni).weight.data.mean().item(), 1.)\n",
    "test_eq(BN1d(ni, zero_norm=True).weight.data.mean().item(), 0.)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "ConvBlock(\n",
       "  (0): AddCoords1d()\n",
       "  (1): Conv1d(4, 5, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)\n",
       "  (2): BatchNorm1d(5, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "  (3): Swish()\n",
       ")"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test_eq(ConvBlock(ni, nf, ks, norm='batch', zero_norm=True)[1].weight.data.unique().item(), 0)\n",
    "test_ne(ConvBlock(ni, nf, ks, norm='batch', zero_norm=False)[1].weight.data.unique().item(), 0)\n",
    "test_eq(ConvBlock(ni, nf, ks, bias=False)[0].bias, None)\n",
    "ConvBlock(ni, nf, ks, act=Swish, coord=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "class LinLnDrop(nn.Sequential):\n",
    "    \"Module grouping `LayerNorm1d`, `Dropout` and `Linear` layers\"\n",
    "    def __init__(self, n_in, n_out, ln=True, p=0., act=None, lin_first=False):\n",
    "        layers = [nn.LayerNorm(n_out if lin_first else n_in)] if ln else []\n",
    "        if p != 0: layers.append(nn.Dropout(p))\n",
    "        lin = [nn.Linear(n_in, n_out, bias=not ln)]\n",
    "        if act is not None: lin.append(act)\n",
    "        layers = lin+layers if lin_first else layers+lin\n",
    "        super().__init__(*layers)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "LinLnDrop(\n",
       "  (0): LayerNorm((2,), eps=1e-05, elementwise_affine=True)\n",
       "  (1): Dropout(p=0.5, inplace=False)\n",
       "  (2): Linear(in_features=2, out_features=3, bias=False)\n",
       ")"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "LinLnDrop(2, 3, p=.5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "class LambdaPlus(Module):\n",
    "    def __init__(self, func, *args, **kwargs): self.func,self.args,self.kwargs=func,args,kwargs\n",
    "    def forward(self, x): return self.func(x, *self.args, **self.kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "class Squeeze(Module):\n",
    "    def __init__(self, dim=-1): self.dim = dim\n",
    "    def forward(self, x): return x.squeeze(dim=self.dim)\n",
    "    def __repr__(self): return f'{self.__class__.__name__}(dim={self.dim})'\n",
    "\n",
    "\n",
    "class Unsqueeze(Module):\n",
    "    def __init__(self, dim=-1): self.dim = dim\n",
    "    def forward(self, x): return x.unsqueeze(dim=self.dim)\n",
    "    def __repr__(self): return f'{self.__class__.__name__}(dim={self.dim})'\n",
    "\n",
    "\n",
    "class Add(Module):\n",
    "    def forward(self, x, y): return x.add(y)\n",
    "    def __repr__(self): return f'{self.__class__.__name__}'\n",
    "\n",
    "\n",
    "class Concat(Module):\n",
    "    def __init__(self, dim=1): self.dim = dim\n",
    "    def forward(self, *x): return torch.cat(*x, dim=self.dim)\n",
    "    def __repr__(self): return f'{self.__class__.__name__}(dim={self.dim})'\n",
    "\n",
    "\n",
    "class Permute(Module):\n",
    "    def __init__(self, *dims): self.dims = dims\n",
    "    def forward(self, x): return x.permute(self.dims)\n",
    "    def __repr__(self): return f\"{self.__class__.__name__}(dims={', '.join([str(d) for d in self.dims])})\"\n",
    "    \n",
    "    \n",
    "class Transpose(Module):\n",
    "    def __init__(self, *dims, contiguous=False): self.dims, self.contiguous = dims, contiguous\n",
    "    def forward(self, x): \n",
    "        if self.contiguous: return x.transpose(*self.dims).contiguous()\n",
    "        else: return x.transpose(*self.dims)\n",
    "    def __repr__(self): \n",
    "        if self.contiguous: return f\"{self.__class__.__name__}(dims={', '.join([str(d) for d in self.dims])}).contiguous()\"\n",
    "        else: return f\"{self.__class__.__name__}({', '.join([str(d) for d in self.dims])})\"\n",
    "    \n",
    "    \n",
    "class View(Module):\n",
    "    def __init__(self, *shape): self.shape = shape\n",
    "    def forward(self, x): return x.view(x.shape[0], *self.shape)\n",
    "    def __repr__(self): return f\"{self.__class__.__name__}({', '.join(['bs'] + [str(s) for s in self.shape])})\"\n",
    "    \n",
    "    \n",
    "class Reshape(Module):\n",
    "    def __init__(self, *shape): self.shape = shape\n",
    "    def forward(self, x): return x.reshape(x.shape[0], *self.shape)\n",
    "    def __repr__(self): return f\"{self.__class__.__name__}({', '.join(['bs'] + [str(s) for s in self.shape])})\"\n",
    "    \n",
    "    \n",
    "class Max(Module):\n",
    "    def __init__(self, dim=None, keepdim=False): self.dim, self.keepdim = dim, keepdim\n",
    "    def forward(self, x): return x.max(self.dim, keepdim=self.keepdim)[0]\n",
    "    def __repr__(self): return f'{self.__class__.__name__}(dim={self.dim}, keepdim={self.keepdim})'\n",
    "\n",
    "    \n",
    "class LastStep(Module):\n",
    "    def forward(self, x): return x[..., -1]\n",
    "    def __repr__(self): return f'{self.__class__.__name__}()'\n",
    "    \n",
    "    \n",
    "class SoftMax(Module):\n",
    "    \"SoftMax layer\"\n",
    "    def __init__(self, dim=-1):\n",
    "        self.dim = dim\n",
    "    def forward(self, x):\n",
    "        return F.softmax(x, dim=self.dim)\n",
    "    def __repr__(self): return f'{self.__class__.__name__}(dim={self.dim})' \n",
    "    \n",
    "\n",
    "class Clamp(Module):\n",
    "    def __init__(self, min=None, max=None):\n",
    "        self.min, self.max = min, max\n",
    "    def forward(self, x):\n",
    "        return x.clamp(min=self.min, max=self.max)\n",
    "    def __repr__(self): return f'{self.__class__.__name__}(min={self.min}, max={self.max})' \n",
    "    \n",
    "    \n",
    "class Clip(Module):\n",
    "    def __init__(self, min=None, max=None):\n",
    "        self.min, self.max = min, max\n",
    "        \n",
    "    def forward(self, x):\n",
    "        if self.min is not None:\n",
    "            x = torch.maximum(x, self.min)\n",
    "        if self.max is not None:\n",
    "            x = torch.minimum(x, self.max)\n",
    "        return x\n",
    "    def __repr__(self): return f'{self.__class__.__name__}()' \n",
    "    \n",
    "    \n",
    "class ReZero(Module):\n",
    "    def __init__(self, module):\n",
    "        self.module = module\n",
    "        self.alpha = nn.Parameter(torch.zeros(1))\n",
    "    def forward(self, x):\n",
    "        return x + self.alpha * self.module(x)\n",
    "    \n",
    "    \n",
    "Noop = nn.Sequential()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(Transpose(1, 2),\n",
       " Permute(dims=0, 2, 1),\n",
       " View(bs, -1, 2, 10),\n",
       " Transpose(dims=1, 2).contiguous(),\n",
       " Reshape(bs, -1, 2, 10),\n",
       " Sequential())"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "bs = 2\n",
    "nf = 5\n",
    "sl = 4\n",
    "\n",
    "t = torch.rand(bs, nf, sl)\n",
    "test_eq(Permute(0,2,1)(t).shape, (bs, sl, nf))\n",
    "test_eq(Max(1)(t).shape, (bs, sl))\n",
    "test_eq(Transpose(1,2)(t).shape, (bs, sl, nf))\n",
    "test_eq(Transpose(1,2, contiguous=True)(t).shape, (bs, sl, nf))\n",
    "test_eq(View(-1, 2, 10)(t).shape, (bs, 1, 2, 10))\n",
    "test_eq(Reshape(-1, 2, 10)(t).shape, (bs, 1, 2, 10))\n",
    "Transpose(1,2), Permute(0,2,1), View(-1, 2, 10), Transpose(1,2, contiguous=True), Reshape(-1, 2, 10), Noop"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "\n",
    "class DropPath(nn.Module):\n",
    "    \"\"\"Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).\n",
    "\n",
    "    It's similar to Dropout but it drops individual connections instead of nodes.\n",
    "    Original code in https://github.com/rwightman/pytorch-image-models (timm library)\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, p=None):\n",
    "        super().__init__()\n",
    "        self.p = p\n",
    "\n",
    "    def forward(self, x):\n",
    "        if self.p == 0. or not self.training: return x\n",
    "        keep_prob = 1 - self.p\n",
    "        shape = (x.shape[0],) + (1,) * (x.ndim - 1)\n",
    "        random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)\n",
    "        random_tensor.floor_()\n",
    "        output = x.div(keep_prob) * random_tensor\n",
    "#         output = x.div(random_tensor.mean()) * random_tensor # divide by the actual mean to mantain the input mean?\n",
    "        return output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "t = torch.ones(100,2,3)\n",
    "test_eq(DropPath(0.)(t), t)\n",
    "assert DropPath(0.5)(t).max() >= 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "class Sharpen(Module):\n",
    "    \"This is used to increase confidence in predictions - MixMatch paper\"\n",
    "    def __init__(self, T=.5): self.T = T\n",
    "    def forward(self, x):\n",
    "        x = x**(1. / self.T)\n",
    "        return x / x.sum(dim=1, keepdims=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD4CAYAAAD8Zh1EAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAmxElEQVR4nO3de7xVc/7H8denK0pJhaQ6mYpJRM4vucy4hXLLncYtzDQ1cp9MhEluYxhDM6JcphAJg1QuTWoaUTqpUMQpXaVSqXQ/zvf3x2efOh3nss85e599ez8fj22vvdbaa31W+/js7/6u78VCCIiISOqrlugAREQkNpTQRUTShBK6iEiaUEIXEUkTSugiImmiRqJO3KhRo5CVlZWo04uIpKQZM2Z8H0JoXNy2hCX0rKwscnJyEnV6EZGUZGaLStqmKhcRkTShhC4ikiaU0EVE0oQSuohImlBCFxFJE2UmdDN71sxWmtnnJWw3MxtkZrlm9qmZdYh9mCIiUpZoSujDgC6lbO8KtI48egJPVD4sEREprzLboYcQJptZVim7dAOeCz4O71Qz28vMmoQQlscqSBGRWPjpJ9iyxR+bN+98Lrxc9DkvD0Io3yM/v9Dr/EBYv4HwwzrCunWEdRs4q0dD/u/SNjG/vlh0LGoKLCn0emlk3c8Supn1xEvxNG/ePAanFhHZaf16+PxzmDkTZs2CFSvgu+/8sXIlbN2aiKgMqBd5NANg/4YTkzahRy2EMBQYCpCdna2ZNUSkwkKASZNg2jT4+GP48kt/FMzZ07AhNG0KTZpA27awzz5Qty7sthvsvvvPn4tbt9tuUKMGmBV5rP4ee2s0NnsW9vE07OuvsO1bMQLVCNgBTXd9/OJArGUW1rwZNGsG9U+Iy79JLBL6Mgq+dtwBkXUiIjG1fTu88QaMHetJ/IsvfH3r1nDIIXDRRXDkkdCuHWRlefKNmY0b4d//hjff9EdeHtSrB8ceC2f8Dn7xC9h/fzjuOGjQIIYnjl4sEvpooI+ZjQSOAtap/lxEYmXJEhg3DqZMgdGjYd06aNQIsrPh+us9ie+9dxwDmDULhgyBV16B1av9ZL17w29/C4ceGuNvjcopM6Gb2UvACUAjM1sK/BmoCRBCeBIYB5wO5AKbgKviFayIZI7ly+GGGzyPgleZnHEGXHopdOkC1eLZiyYEr4h/4gl45hmvfzntNLj5ZujUyethklA0rVy6l7E9ANfGLCIRyVghwFtvwfDhXirPz4e+feHqq+Ggg6qgMJyXB6+9BgMGeIV8zZpw+eXwyCNeKZ/kkvNrRkQyzvjx0K8ffPIJ7Lefl8T79YNWrarg5Dk5MGwYjBoFq1ZB+/ZezXLBBXGuz4ktJXQRSbiRI+GKK7xVytChcOWVUKtWnE8aArz4ov8cGD/em7ScdRZccgmcfXbSVquUJvUiFpG0EQL86U/w0ENw+OF+07NZszLfVvmTTpwITz8NL73kifyOO+CWW2CvveJ88vhSQheRhJg3z6unp0/3RiOPPeZV1nE1fz706gX/+Y8n8j594O9/T8nSeHHS4ypEJKV8+y107uxNEP/2N7jxxipotfLyy3DrrbB2Ldx5py/XrRvHk1Y9JXQRqVJTp3ohecUKL523bx/nE06bBtdd5ydr3Rreecc7A6UhjYcuIlVm+HA45hhP5sOGxTmZ5+fDxRd7u/HcXBg82LuWpmkyB5XQRaQK/PijN0F8/HE48UR4/XWoXz/OJ334YW+GeO21MHBgSjU/rCiV0EUk7nr29GTes6fXeMQ9mb/0kn+DdOoEf/1rRiRzUAldROJs0iTPr716eU/6uJs82ZvPHH20t2bZffcqOGlyUAldROJmzhzvp7P//l4DEndffQVXXQWNG3uj9gxK5qCELiJxsn69d7jcsgXefhvq1InzCf/xDx83d8kS726aAmOvxJqqXEQk5n78EU45BRYvhjFj4LDD4nzCN9/0xuwdO3r9TlZWnE+YnJTQRSTmbrnFJ6B46SUfdTaupk/35oktWvi3RwaWzAuoykVEYuqbb7zGo2Ccq7jautVnudh7b+9AlMHJHFRCF5EY+vFHuOYaHyZl8OAqOGHv3t719F//8huhGU4JXURiIj/fWwtOnOjjsxxwQJxPOGyYJ/Jbb4UePeJ8stSgKhcRqbRNm3wuiDfegHvu8Zna4mrcOB+f5eij4f7743yy1KGELiKVdtll3p2/Tx/o3z/OJ/vrX31y0awsn3C0evU4nzB1KKGLSKV8+KEn8759vSl4XOf9HDLEZ8Q45hhvRtO0aRxPlnqU0EWkwrZvhwsv9DlAb789zicbM8arWdq18wFhMqwXaDR0U1REKuzhh32yitdfj/PsbW+95d1ODzsM3n0X9twzjidLXUroIlIhixb5DdDTTvNcGzeff+6tWI44AqZMUcm8FKpyEZEKue8+79fzj3/Ecfq422/3Unl+vk8hp2ReKiV0ESm3sWPhqad8fPPWreN0kocfhgcegHPP9Rml43ai9GEhhIScODs7O+Tk5CTk3CJScSFAy5Y+muLixXGaZ/mjj+Ckk+DUU+G116CGaocLmNmMEEJ2cdtUQheRcnnvPa8/79EjTsl8xgxvlrj//v4zQMk8akroIlIu990HtWvDnXfG4eCbN8M550C9et6aZZ994nCS9KWvPhGJ2syZ8L//wW23QYMGMT54fr4PtrV0KbzwArRqFeMTpD8ldBGJ2j33eEOTvn1jfOD8fK8vnzABzjsPLr00xifIDKpyEZGoTJzoHYh69IhD6XzMGE/mt93mzROlQqJK6GbWxczmmVmumfUrZntzM5toZjPN7FMzOz32oYpIIg0b5uO0PPBAjA88fbqP7vXLX3pC103QCiszoZtZdeBxoCvQFuhuZm2L7HYHMCqEcARwCVAVQ9uLSBWZPRuefx5uugnq14/xwe+4w2+GvvOOuvRXUjQl9I5AbghhQQhhGzAS6FZknwDUiyzXB76NXYgikmhPP+2j1Pb72e/zSlq40OtyevaE5s1jfPDME01CbwosKfR6aWRdYQOAy8xsKTAOuK64A5lZTzPLMbOcVatWVSBcEalq338PI0bACSfEeJa3hQvh2GO9DeRll8XwwJkrVjdFuwPDQggHAKcDz5vZz44dQhgaQsgOIWQ31vx/Iinhqadg7Vr4859jeNCC+erWrIFJk3zmIam0aBL6MqBZodcHRNYVdg0wCiCE8BGwG9AoFgGKSGJNmgRt28Jxx8XogCtXetPEDz6ARx+FI4+M0YElmoQ+HWhtZi3NrBZ+03N0kX0WAycDmNkv8YSuOhWRFLdmDUyeDCefHKMDLl0Kv/61zwk6cKDXnUvMlNk+KISQZ2Z9gHeB6sCzIYQ5ZjYQyAkhjAZuAZ4ys5vwG6Q9QqJG/RKRmBkxArZsgauuitEBr7sOvv4a3nwTzjwzRgeVAhptUUSKtXEjHH64J/QlS8rcvWz/+5+Xzvv08UHUpUJKG21RLfhFpFhDhkBurhemK23TJvjd77yL6YABMTigFEdd/0WkWO+8A1lZMZhebtEib544bx6MGgUNG8YiPCmGErqI/My2bd665dxzY3Cwyy+HTz+F4cOhc+cYHFBKoioXEfmZadNg+3Y46qhKHuiDD7zu/J574IorYhKblEwldBH5mbfe8oG4unat5IGeeQb22gtuvjkWYUkZlNBFZBch+Ai2xxzjEwdV2IwZ8Mor/q2wxx4xi09KpoQuIruYMMEnf77ggkocZMECvxFav75atVQhJXQR2SEE6N8f6tSB3/++Ege6807YutXnBW3TJmbxSemU0EVkh9xc+PhjuPFGn2quQoYPhxdf9KmN2rWLYXRSFiV0Ednh8cf9+fzzK3iAhQu9J+ixx/owjVKllNBFBIC8PBg8GH7zGzjiiAoeZNAg7xX67LOaSi4BlNBFBIAffvC25506VfAA330Hf/87dOumevMEUUIXEcCHygXYe+8KHuDCC/25d++YxCPlp4QuIoD3DgVoWnSCyWi8+ab3Cr3pJjjllJjGJdFTQhcRAIYOhYMOgl/9qgJvfuABb3N+990xj0uip4QuImzaBFOn+sxE1auX4435+XD//V68v/lm2HPPuMUoZdNtaBFh7lxv5XLsseV4Uwjegej+++Hii6Fv37jFJ9FRQheRHZNYnHhiOd7Uty/87W8+RsBLL/loXpJQqnIRET78ENq3hyZNonzD2LGezH//ex/JS8k8KSihi2S49et9yPKTT47yDStXQq9e0KoVPPooVFMaSRaqchHJcO+/7x2KTj89ip3XrYMzzoDvv/dmirvtFvf4JHpK6CIZbtYsrzGJqofoHXdATo7PDXrkkfEOTcpJv5VEMtjmzT444qGH+pC5pVq3zsdo6dFjZ69QSSoqoYtksLvv9gESR44sY8f8fOjZ078Brr++KkKTClAJXSSDvfEGnHqqNyMvUV4eXH65V7M8+GAlhmKUeFNCF8lQubkwb14UN0P79vUJK26/Hf74xyqJTSpGCV0kQ916q89KdM45pey0aJE3TezSBe69V+3Nk5wSukgG2rTJq1t694YWLUrYKQSvaqle3Sd6VjJPeropKpKBZs70fH3MMaXstHix9zh68EE46qgqi00qTiV0kQwTgjdYadiwjKFyp0zx53KN2CWJFFVCN7MuZjbPzHLNrF8J+1xkZnPNbI6ZvRjbMEUkVpYs8dEV//xn2GefEnbavNmrWVq0gI4dqzI8qYQyq1zMrDrwOHAKsBSYbmajQwhzC+3TGrgNODaEsNbMSvozEZEE++QTf+7QoZSdRoyAr7+GMWOgZs0qiUsqL5oSekcgN4SwIISwDRgJdCuyz++Ax0MIawFCCCtjG6aIxMonn/h4WiX23A8BXngBWraMcoAXSRbRJPSmwJJCr5dG1hXWBmhjZlPMbKqZdSnuQGbW08xyzCxn1apVFYtYRCplzhxo3bqUcbXefhv++1+48Ua1bEkxsbopWgNoDZwAdAeeMrO9iu4UQhgaQsgOIWQ3btw4RqcWkWht3+4NV0qtbnnlFZ8ftFevKotLYiOahL4MaFbo9QGRdYUtBUaHELaHEL4BvsITvIgkkf/9D1atKqWrf14ejB4NZ54JtWpVaWxSedEk9OlAazNraWa1gEuA0UX2eQMvnWNmjfAqmAWxC1NEYmFJpPK0XbsSdvjTn2DNGp9WTlJOmQk9hJAH9AHeBb4ARoUQ5pjZQDM7O7Lbu8BqM5sLTAT6hhBWxytoEamYNWv8uWHDYjbm5MAjj8Bll0G3ou0eJBVE1VM0hDAOGFdk3V2FlgNwc+QhIklqwgQfv6VevWI2Pvkk7Lkn/POfuhmaotRTVCRDrF8P48bBb35TzDSgs2bBsGHQtavfEJWUpIQukiFeecWbmF9xRZENy5ZB9+6w117w8MOJCE1iRINziWSIV1+FVq2KjN+yaROccIIPkztmDDRrVtLbJQWohC6SAdav9/rzc84pUj3+wgs+08WTT/rURZLSlNBFMsDbb3unol0ms9iyBe66ywffuuqqRIUmMaQqF5EMMGgQ7LsvdOpUaOWoUbBihZfS1aolLaiELpLmVq6EDz+EPn188iEAfvrJS+cHHQQnn5zQ+CR2VEIXSXPTp/vz8ccXWvnSS34jdPBglc7TiEroImnu44+93fmOAbm+/BKuvRbatPGpiyRtKKGLpLlp03zsljp18MG3rr7an8eMKVQHI+lACV0kjYXgJfQds8iNGAEffQR33OGDoktaUUIXSWOrVsHatXDYYXideZ8+0LYt9O2b6NAkDpTQRdLYyshkkPvtB/TrB1u3whtvQA21h0hHSugiaWzsWH9usupTGDnSOxCpqiVtmY98W/Wys7NDTk5OQs4tkgny8qBBAzj80DwmrWhL9Q0/wIIFULduokOTSjCzGSGE7OK26XeXSJrKzYUff4Tf1h1J9Y++hokTlczTnKpcRNLU22/7c4eJj8A11/ioipLWlNBF0tSzz0L7X/xIu7yZcNppiQ5HqoASukgaWroUPv8cLs1/Dtttt0IN0SWdKaGLpKGpU/35pG+e8eaKLVokNiCpEkroImnoq9FfAtC62Va45ZYERyNVRa1cRNJNCEx4eRWHVMun3v/GqmVLBlEJXSTNzP9gOZO2HcO5J29QVUuGUUIXSTPj/rmAfKpz9R9qJzoUqWJK6CJpZsyUBmTZQrLObp/oUKSKKaGLpJHt4yfxwbIszmrxKVZNMxFlGiV0kXQxfz6Pn/sfNlGHE+46vuz9Je0ooYukg3XroHNnRm48k1bNt9LtivqJjkgSQAldJB386U/MXbg70+hE9ytra2a5DKWELpLqXnsNhgzhg04+C9HZZyc4HkkYJXSRVLZuHfzhD9CuHV8ffQW1a0OHDokOShIlqoRuZl3MbJ6Z5ZpZv1L2O9/MgpkVO/i6iMRY//4+z9w//8mCRdXJyoJqKqZlrDI/ejOrDjwOdAXaAt3NrG0x++0J3ABMi3WQIlKMSZNg8GDo1g2OP56pU+HQQxMdlCRSNN/lHYHcEMKCEMI2YCTQrZj97gEeBLbEMD4RKcmjj8K++8KIEXz7LXz7LRx7bKKDkkSKJqE3BZYUer00sm4HM+sANAshjC3tQGbW08xyzCxn1apV5Q5WRCJmzIDRo+HSS6FOHSZP9tWdOiU2LEmsSte2mVk14BGgzDE6QwhDQwjZIYTsxo0bV/bUIplrbKTsdNttAEyYALvvrhuimS6ahL4MaFbo9QGRdQX2BNoBk8xsIdAJGK0boyJxsnIlPPMMtG8PDRsCkJMD//d/UKtWgmOThIomoU8HWptZSzOrBVwCjC7YGEJYF0JoFELICiFkAVOBs0MIOXGJWCTT/eUvsHixPwNTpsCsWXDKKYkNSxKvzIQeQsgD+gDvAl8Ao0IIc8xsoJmpC4NIVVqxAp5+Gs44Y8fEz+PHe1PF665LcGyScFHNWBRCGAeMK7LurhL2PaHyYYlIsf7yF9iyBQYM2LFq8mQ4/HCor+FbMp66IIikisWL4bHH4MQTIdtvUc2bBx98ACeckNjQJDkooYukimHDIARvfx4xfLiv0jzQAkroIqlh/Hh46CHo3Bl++csdq997z9ue779/AmOTpKGELpLstm2D3/0OGjSAZ5/dsXrsWO9fdOaZCYxNkkpUN0VFJIEefhgWLYI334RmO7uE3HOP9/y/8cbEhSbJRSV0kWQ2fz7ceSeccw6cddaO1Rs3wrRp8PvfQ+3aiQtPkosSukiy2rbN25rn58ODD4LtnPR55kx/PvzwxIQmyUkJXSRZDRrkJfSRI6FNm102DRwIe+wBRx+doNgkKSmhiySjvDx45BE4/ni4+OKfbZ45E7p3h/32S0BskrR0U1Qk2YQAd90Fy5f7BBZFfPklfP/9zwrtIiqhiySdIUPggQfg6qt9NqIiXn/dn3v0qNqwJPkpoYskk61b4d57feqhp5/e5UZogZdfhiZNYJ99EhCfJDUldJFksWULnH46LFsGN9xQbDJfswZmzy62Wl1ECV0kaQwaBO+/7/XmF15Y7C4FU82pd6gURwldJBmsXQv33QfHHQe9e5e42+zZXnDX3KFSHCV0kUTbtAl+9StYv957hZZi9mxo3Rrq1Kmi2CSlKKGLJNp558GcOT5P6Kmnlrjb6tVeI3PEEVUYm6QUJXSRRBo3Dt5919udX311qbuOGgXr1sFNN1VRbJJylNBFEmXtWrj2WmjRAvr3L3P3mTOhYUPo2LEKYpOUpJ6iIomwZo1PVrFoEUyaBLVqlfmWuXOhbdtiWzOKACqhi1S9ZcvgpJM8Q7/+Ovz612W+JYSdCV2kJCqhi1Sl5ct9RufvvvMJK047Laq3jR/vNTTt28c3PEltSugiVSU/33uC5uZ6hu7cOeq3DhkCNWrA5ZfHMT5JeapyEakKeXlwxRUwaxb8/e/lSubLlsFbb0GvXlC3bvxClNSnhC5SFW65BUaM8NYsN9xQrre+9hps3w59+sQpNkkbqnIRiacQfHqhQYO8nfk995S7mcpnn0GDBhr/XMqmhC4SLxs3wjXX+Hi33bvDY4+VO5lPngzDh0PXrmquKGVTlYtIvPTt68n8j3+E556rUAX4jTfCvvvC0KGxD0/SjxK6SDw89xw88QRcdBE89JA3USmnLVvgiy98JN19941DjJJ2lNBFYm3iRLj+ejjmGB9wq4Jee82TeteuMYxN0lpUCd3MupjZPDPLNbN+xWy/2czmmtmnZjbBzFrEPlSRFDBtmncW2msvr/yuRDvDp56Cpk29U6lINMpM6GZWHXgc6Aq0BbqbWdEOyDOB7BDCYcCrwF9jHahI0nvjDc++e+8NM2ZAq1YVPtT8+fDf/3rDmOrVYxeipLdoSugdgdwQwoIQwjZgJLDLVOQhhIkhhE2Rl1OBA2IbpkiS+/e/vb68SRP48EMfFrES7rjDE7l6hkp5RJPQmwJLCr1eGllXkmuAt4vbYGY9zSzHzHJWrVoVfZQiyWziRDj/fJ954oMP4MADK324kSO9L1Lr1jGKUTJCTG+KmtllQDbwUHHbQwhDQwjZIYTsxo0bx/LUIokRgvf8rFXLJ6rYb79KHW7RIjj7bGjZUhNZSPlF05ZqGdCs0OsDIut2YWadgf7A8SGErbEJTySJrV3r08d99hncf7/fCK2kYcO8P9Jnn1X6u0EyUDQl9OlAazNraWa1gEuA0YV3MLMjgCHA2SGElbEPUyTJfPwxZGf75BQDB3r9SCW9/TbcfbePrpuVVenDSQYqM6GHEPKAPsC7wBfAqBDCHDMbaGZnR3Z7CKgLvGJms8xsdAmHE0l9Dz4InTp5Ufqtt+DOO6Oacag0q1f76ACNGnkpXaQiouq+FkIYB4wrsu6uQsvRjwUqksrefBP69YNTToFXX4V69WJy2IEDYcMGH123efOYHFIykHqKikRr+HBvzXLwwd4MJUbJfMUKHyXgssvg0ENjckjJUEroItGYOxd69IB27WDCBO88FCMF451fdVXMDikZSgldpCzDh8Opp/ryK6/A/vvH9PAvvujtzaOYK1qkVEroIiX57jvve9+jh7chnDIl5j19Ro3yw55/PlTT/41SSfoTEilq82YYMACaNYN//cvHM//wQx89MYZWr/ZDH3KI3xQVqSzNWCRS2KxZPhvztGnejvCuu/wmaBz06eMTQI8YATVrxuUUkmGU0EXAS+Vduvicb/XqwQsvwKWXxu10Q4d6Q5nrroNf/Spup5EMo4QuAj6J8+TJ3lWzd2+I41hDn37qHUuPOw4efjhup5EMpIQumeu773wM89Gjvd/9iSd6r884zsa8fDl06wa77eY/AirZwVRkF0roknnWrfOi8UMPwdat3oLlzjt9EPI4JvMFC3z4l82bffKKFprXS2JMCV0yy5dfQufOfjeye3dP5AcfHNdEXmDAAB+g8aWXoGPHuJ9OMpASumSGEOD11+GSS7xJyTvv+NyfVWTCBL8J2quXhyASD2qHLult+3avWmnRwnvvtGsH8+ZVaTL/7DM4/XSfyOj++6vstJKBlNAlfX33HRx5JNx6Kxx0kLcV/OADOKDqprxdtgy6dvWboG+/DQ0aVNmpJQOpykXSz/bt8I9/wCOP7Ky0vvjiKqknLywEuOYaT+rTp/u0ciLxpBK6pI/Nm2HsWK9WueUWr2aZMMErras4mW/ZAhdd5NOM9urlrVtE4k0JXdLDs8/6KIhnngmrVsHTT3v1SqdOVR7K/Pk+Psurr8J998HgwVUegmQoVblIavvkEx8U5aOP4OijvRniSSdB7doJCWf9ejj+eH9+7z2f2EikqiihS2r6+GPo3x/+8x+/49i/v08NV7duwkLKz4dzz/XeoOPH+/eKSFVSlYuklqVL4bzz4KijvD3g/ffDokVw770JTeavvgqHHw7vv++tJJXMJRFUQpfkF4KXyF9/3ccn37ABbrsNbroproNoRWPDBv9hMHgwHHaYV+X36JHQkCSDKaFLcgoBpk71Kd9efRWWLIEaNbzoe/vtXlGdQHl5MGyYT0yxbBlcf70PD6NxzSWRlNAluaxd633kn3vOE3qtWt6r89574ayzkqJnzoYNHsp//+vVLE8/vXPKUZFEUkKX5LB6tffkfOABz5hNm3rnoMsvh/r1Ex0d4P2VRo70KpZvv4VHH/WSeRU3cRcpkRK6JEYI8M03MG4cPPGEj4KYn+/zdj7wgE/jkySZcv16D3HQIE/khxzi1S1qkijJRgldqkZent/YHD8eJk3y9uPr1/u2gw7y9uPnnON1GEli2TKfhOKJJ7whTefOMGSID7RVTe3DJAkpoUt8rF7tbfgmT4a5cz2B//CDl7qPOMLn62zf3udha9s2aUrj+fkwZgw89hhMnOg/JDp2hOef19yfkvyU0KVy8vO9HmL2bE/an3wCM2d6kRa8bfjBB8OFF3oLlVNOgYYNExtzESH4j4bnnvNk/v330Ly5/2i49FJo0ybREYpERwldord9uw9UMns2zJoFn3/uPTW3bNm5T5s2Pn5K795epM3OTsqJM7dv98t4+WWfUvSrr7wBzYkn+rDp553nHVBFUokSuuy0caPfqCx4zJ/vj2+/9QrllSt37luzpo8H26OH96hp187rv/fcM1HRl2jzZvjiC//+mTYNpkzx5Z9+8pqek07yfkoXXwy7757oaEUqTgk9U2zc6El5+XLvPl/wWLLEH4sXe11DYXXqQKtW3oQwO9ufDzzQm3kcemjS9aL56Sev6Zk3zxP4t9/6gIszZvg9WfAaoKOP9jkvDj3US+T77ZfYuEViJaqEbmZdgMeA6sDTIYS/FNleG3gOOBJYDVwcQlgY21CFELyN9g8/7PpYu3bn88aNsGmTP69c6Ul8yRLfVlT9+tCsmT+OPBKysrzU3bKlJ+7GjZPmZiV4df2aNbBiBSxcCLm5/rx4sSfx3FzYunXn/rVr+33Xvn2hQwf/HmrTBqpXT9QViMRXmQndzKoDjwOnAEuB6WY2OoQwt9Bu1wBrQwitzOwS4EHg4ngEHDch+CM/f9fHTz/5Iy9v52P79pJfF7e8fTts27bzeds2rwcoSLybNu18lPb6xx89ltLsvjvssYeXrhs39rt7xx3nSbtpU2jSZOdyJQazCuHn/yylPYr+kxX9J9i40b+r1q/3x4YNO1//8IOXtleu9I+ksDp1/HLatPHmhAcd5PdgDz4Y9t47qb6PROIumhJ6RyA3hLAAwMxGAt2Awgm9GzAgsvwq8E8zsxBCiGGsADx71WQefnF/wJOKL/h/ArbrcmTbjiCKLO+yH0S1vPN1TaBmFPuVtVzNs44ZVLNCryHYzm2BalA7sk/B+mr+vMt+IXLsjRB+BL4BPiz0b1Vw3lC+ZfBkWpCQy/peqagaNfyHw557+qN+ff8O6tDBq0YaN/ZHVpbXBjVqpKQtUiCahN4UWFLo9VLgqJL2CSHkmdk6oCGwS6WsmfUEegI0b968QgE32r827fZdBXgCM/PFXV77ySLbrGBT6e/ZsZ/t2LZjfTXDqlXbsYxVw6pFEmo1fFtku1UvSLTVdr6vWjWoXs2Xq/trq1kDqtfYNeadoZdrOZ7vKfzazBNuzZr+XJFHzZpe5VG7tj/q1PHHHntAvXq+TglapGKq9KZoCGEoMBQgOzu7QqX3s+87irPvi2lYIiJpIZoOzMuAZoVeHxBZV+w+ZlYDqI/fHBURkSoSTUKfDrQ2s5ZmVgu4BBhdZJ/RwJWR5QuA9+NRfy4iIiUrs8olUifeB3gXb7b4bAhhjpkNBHJCCKOBZ4DnzSwXWIMnfRERqUJR1aGHEMYB44qsu6vQ8hbgwtiGJiIi5aFBQEVE0oQSuohImlBCFxFJE0roIiJpwhLVutDMVgGLKvj2RhTphZrCdC3JKV2uJV2uA3QtBVqEEBoXtyFhCb0yzCwnhJCd6DhiQdeSnNLlWtLlOkDXEg1VuYiIpAkldBGRNJGqCX1oogOIIV1LckqXa0mX6wBdS5lSsg5dRER+LlVL6CIiUoQSuohImki5hG5mXcxsnpnlmlm/RMdTFjNbaGafmdksM8uJrNvbzMab2deR5waR9WZmgyLX9qmZdUhw7M+a2Uoz+7zQunLHbmZXRvb/2syuLO5cCbqWAWa2LPLZzDKz0wttuy1yLfPM7LRC6xP692dmzcxsopnNNbM5ZnZDZH3KfS6lXEsqfi67mdnHZjY7ci13R9a3NLNpkbhejgxBjpnVjrzOjWzPKusaoxJCSJkHPnzvfOBAoBYwG2ib6LjKiHkh0KjIur8C/SLL/YAHI8unA2/jk+B1AqYlOPZfAx2AzysaO7A3sCDy3CCy3CBJrmUA8Mdi9m0b+duqDbSM/M1VT4a/P6AJ0CGyvCfwVSTelPtcSrmWVPxcDKgbWa4JTIv8e48CLomsfxLoHVn+A/BkZPkS4OXSrjHaOFKthL5jwuoQwjagYMLqVNMNGB5ZHg6cU2j9c8FNBfYysyYJiA+AEMJkfHz7wsob+2nA+BDCmhDCWmA80CXuwRdRwrWUpBswMoSwNYTwDZCL/+0l/O8vhLA8hPBJZHkD8AU+p2/KfS6lXEtJkvlzCSGEHyMva0YeATgJeDWyvujnUvB5vQqcbGZGydcYlVRL6MVNWF3aH0AyCMB7ZjbDfJJsgH1DCMsjy98B+0aWU+H6yht7sl9Tn0hVxLMF1RSkyLVEfqYfgZcGU/pzKXItkIKfi5lVN7NZwEr8C3I+8EMIIa+YuHbEHNm+DmhIJa8l1RJ6KjouhNAB6Apca2a/Lrwx+O+slGw7msqxRzwB/AI4HFgO/C2h0ZSDmdUFXgNuDCGsL7wt1T6XYq4lJT+XEMJPIYTD8XmXOwIHV3UMqZbQo5mwOqmEEJZFnlcCr+Mf9IqCqpTI88rI7qlwfeWNPWmvKYSwIvI/YT7wFDt/2ib1tZhZTTwBjggh/DuyOiU/l+KuJVU/lwIhhB+AicDReBVXwcxwhePaEXNke31gNZW8llRL6NFMWJ00zKyOme1ZsAycCnzOrpNqXwm8GVkeDVwRaZnQCVhX6Gd0sihv7O8Cp5pZg8hP51Mj6xKuyP2Jc/HPBvxaLom0RGgJtAY+Jgn+/iL1rM8AX4QQHim0KeU+l5KuJUU/l8ZmtldkeXfgFPyewETggshuRT+Xgs/rAuD9yC+rkq4xOlV5JzgWD/yu/Vd4/VT/RMdTRqwH4nesZwNzCuLF68omAF8D/wH2DjvvlD8eubbPgOwEx/8S/pN3O16Xd01FYgeuxm/u5AJXJdG1PB+J9dPI/0hNCu3fP3It84CuyfL3BxyHV6d8CsyKPE5Pxc+llGtJxc/lMGBmJObPgbsi6w/EE3Iu8ApQO7J+t8jr3Mj2A8u6xmge6vovIpImUq3KRURESqCELiKSJpTQRUTShBK6iEiaUEIXEUkTSugiImlCCV1EJE38P3SKFb4Kd5BYAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "n_samples = 1000\n",
    "n_classes = 3\n",
    "\n",
    "t = (torch.rand(n_samples, n_classes) - .5) * 10\n",
    "probas = F.softmax(t, -1)\n",
    "sharpened_probas = Sharpen()(probas)\n",
    "plt.plot(probas.flatten().sort().values, color='r')\n",
    "plt.plot(sharpened_probas.flatten().sort().values, color='b')\n",
    "plt.show()\n",
    "test_gt(sharpened_probas[n_samples//2:].max(-1).values.sum().item(), probas[n_samples//2:].max(-1).values.sum().item())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "class Sequential(nn.Sequential):\n",
    "    \"\"\"Class that allows you to pass one or multiple inputs\"\"\"\n",
    "    def forward(self, *x):\n",
    "        for i, module in enumerate(self._modules.values()): \n",
    "            x = module(*x) if isinstance(x, (list, tuple, L)) else module(x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "class TimeDistributed(nn.Module):\n",
    "    def __init__(self, module, batch_first=False):\n",
    "        super(TimeDistributed, self).__init__()\n",
    "        self.module = module\n",
    "        self.batch_first = batch_first\n",
    "\n",
    "    def forward(self, x):\n",
    "\n",
    "        if len(x.size()) <= 2:\n",
    "            return self.module(x)\n",
    "\n",
    "        # Squash samples and timesteps into a single axis\n",
    "        x_reshape = x.contiguous().view(-1, x.size(-1))  # (samples * timesteps, input_size)\n",
    "\n",
    "        y = self.module(x_reshape)\n",
    "\n",
    "        # We have to reshape Y\n",
    "        if self.batch_first:\n",
    "            y = y.contiguous().view(x.size(0), -1, y.size(-1))  # (samples, timesteps, output_size)\n",
    "        else:\n",
    "            y = y.view(-1, x.size(1), y.size(-1))  # (timesteps, samples, output_size)\n",
    "\n",
    "        return y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "class Temp_Scale(Module):\n",
    "    \"Used to perform Temperature Scaling (dirichlet=False) or Single-parameter Dirichlet calibration (dirichlet=True)\"\n",
    "    def __init__(self, temp=1., dirichlet=False):\n",
    "        self.weight = nn.Parameter(tensor(temp))\n",
    "        self.bias = None\n",
    "        self.log_softmax = dirichlet\n",
    "\n",
    "    def forward(self, x):\n",
    "        if self.log_softmax: x = F.log_softmax(x, dim=-1)\n",
    "        return x.div(self.weight)\n",
    "\n",
    "\n",
    "class Vector_Scale(Module):\n",
    "    \"Used to perform Vector Scaling (dirichlet=False) or Diagonal Dirichlet calibration (dirichlet=True)\"\n",
    "    def __init__(self, n_classes=1, dirichlet=False):\n",
    "        self.weight = nn.Parameter(torch.ones(n_classes))\n",
    "        self.bias = nn.Parameter(torch.zeros(n_classes))\n",
    "        self.log_softmax = dirichlet\n",
    "\n",
    "    def forward(self, x):\n",
    "        if self.log_softmax: x = F.log_softmax(x, dim=-1)\n",
    "        return x.mul(self.weight).add(self.bias)\n",
    "\n",
    "\n",
    "class Matrix_Scale(Module):\n",
    "    \"Used to perform Matrix Scaling (dirichlet=False) or Dirichlet calibration (dirichlet=True)\"\n",
    "    def __init__(self, n_classes=1, dirichlet=False):\n",
    "        self.ms = nn.Linear(n_classes, n_classes)\n",
    "        self.ms.weight.data = nn.Parameter(torch.eye(n_classes))\n",
    "        nn.init.constant_(self.ms.bias.data, 0.)\n",
    "        self.weight = self.ms.weight\n",
    "        self.bias = self.ms.bias\n",
    "        self.log_softmax = dirichlet\n",
    "\n",
    "    def forward(self, x):\n",
    "        if self.log_softmax: x = F.log_softmax(x, dim=-1)\n",
    "        return self.ms(x)\n",
    "    \n",
    "    \n",
    "def get_calibrator(calibrator=None, n_classes=1, **kwargs):\n",
    "    if calibrator is None or not calibrator: return noop\n",
    "    elif calibrator.lower() == 'temp': return Temp_Scale(dirichlet=False, **kwargs)\n",
    "    elif calibrator.lower() == 'vector': return Vector_Scale(n_classes=n_classes, dirichlet=False, **kwargs)\n",
    "    elif calibrator.lower() == 'matrix': return Matrix_Scale(n_classes=n_classes, dirichlet=False, **kwargs)\n",
    "    elif calibrator.lower() == 'dtemp': return Temp_Scale(dirichlet=True, **kwargs)\n",
    "    elif calibrator.lower() == 'dvector': return Vector_Scale(n_classes=n_classes, dirichlet=True, **kwargs)\n",
    "    elif calibrator.lower() == 'dmatrix': return Matrix_Scale(n_classes=n_classes, dirichlet=True, **kwargs)\n",
    "    else: assert False, f'please, select a correct calibrator instead of {calibrator}'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bs = 2\n",
    "c_out = 3\n",
    "\n",
    "t = torch.rand(bs, c_out)\n",
    "for calibrator, cal_name in zip(['temp', 'vector', 'matrix'], ['Temp_Scale', 'Vector_Scale', 'Matrix_Scale']): \n",
    "    cal = get_calibrator(calibrator, n_classes=c_out)\n",
    "#     print(calibrator)\n",
    "#     print(cal.weight, cal.bias, '\\n')\n",
    "    test_eq(cal(t), t)\n",
    "    test_eq(cal.__class__.__name__, cal_name)\n",
    "for calibrator, cal_name in zip(['dtemp', 'dvector', 'dmatrix'], ['Temp_Scale', 'Vector_Scale', 'Matrix_Scale']):\n",
    "    cal = get_calibrator(calibrator, n_classes=c_out)\n",
    "#     print(calibrator)\n",
    "#     print(cal.weight, cal.bias, '\\n')\n",
    "    test_eq(cal(t), F.log_softmax(t, dim=1))\n",
    "    test_eq(cal.__class__.__name__, cal_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bs = 2\n",
    "c_out = 3\n",
    "\n",
    "t = torch.rand(bs, c_out)\n",
    "\n",
    "test_eq(Temp_Scale()(t).shape, t.shape)\n",
    "test_eq(Vector_Scale(c_out)(t).shape, t.shape)\n",
    "test_eq(Matrix_Scale(c_out)(t).shape, t.shape)\n",
    "test_eq(Temp_Scale(dirichlet=True)(t).shape, t.shape)\n",
    "test_eq(Vector_Scale(c_out, dirichlet=True)(t).shape, t.shape)\n",
    "test_eq(Matrix_Scale(c_out, dirichlet=True)(t).shape, t.shape)\n",
    "\n",
    "test_eq(Temp_Scale()(t), t)\n",
    "test_eq(Vector_Scale(c_out)(t), t)\n",
    "test_eq(Matrix_Scale(c_out)(t), t)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bs = 2\n",
    "c_out = 5\n",
    "\n",
    "t = torch.rand(bs, c_out)\n",
    "test_eq(Vector_Scale(c_out)(t), t)\n",
    "test_eq(Vector_Scale(c_out).weight.data, torch.ones(c_out))\n",
    "test_eq(Vector_Scale(c_out).weight.requires_grad, True)\n",
    "test_eq(type(Vector_Scale(c_out).weight), torch.nn.parameter.Parameter)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bs = 2\n",
    "c_out = 3\n",
    "weight = 2\n",
    "bias = 1\n",
    "\n",
    "t = torch.rand(bs, c_out)\n",
    "test_eq(Matrix_Scale(c_out)(t).shape, t.shape)\n",
    "test_eq(Matrix_Scale(c_out).weight.requires_grad, True)\n",
    "test_eq(type(Matrix_Scale(c_out).weight), torch.nn.parameter.Parameter)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "class LogitAdjustmentLayer(Module):\n",
    "    \"Logit Adjustment for imbalanced datasets\"\n",
    "    def __init__(self, class_priors):\n",
    "        self.class_priors = class_priors\n",
    "    def forward(self, x):\n",
    "        return x.add(self.class_priors)\n",
    "    \n",
    "LogitAdjLayer = LogitAdjustmentLayer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bs, n_classes = 16, 3\n",
    "class_priors = torch.rand(n_classes)\n",
    "logits = torch.randn(bs, n_classes) * 2\n",
    "test_eq(LogitAdjLayer(class_priors)(logits), logits + class_priors)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "class PPV(Module):\n",
    "    def __init__(self, dim=-1): \n",
    "        self.dim = dim\n",
    "    def forward(self, x): \n",
    "        return torch.gt(x, 0).sum(dim=self.dim).float() / x.shape[self.dim]\n",
    "    def __repr__(self): return f'{self.__class__.__name__}(dim={self.dim})'\n",
    "    \n",
    "\n",
    "class PPAuc(Module):\n",
    "    def __init__(self, dim=-1): \n",
    "        self.dim = dim\n",
    "    def forward(self, x): \n",
    "        x = F.relu(x).sum(self.dim) / (abs(x).sum(self.dim) + 1e-8)\n",
    "        return x\n",
    "    def __repr__(self): return f'{self.__class__.__name__}(dim={self.dim})'\n",
    "    \n",
    "    \n",
    "class MaxPPVPool1d(Module):\n",
    "    \"Drop-in replacement for AdaptiveConcatPool1d - multiplies nf by 2\"\n",
    "    def forward(self, x):\n",
    "        _max = x.max(dim=-1).values\n",
    "        _ppv = torch.gt(x, 0).sum(dim=-1).float() / x.shape[-1]\n",
    "        return torch.cat((_max, _ppv), dim=-1).unsqueeze(2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bs = 2\n",
    "nf = 5\n",
    "sl = 4\n",
    "\n",
    "t = torch.rand(bs, nf, sl)\n",
    "test_eq(MaxPPVPool1d()(t).shape, (bs, nf*2, 1))\n",
    "test_eq(MaxPPVPool1d()(t).shape, AdaptiveConcatPool1d(1)(t).shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "class AdaptiveWeightedAvgPool1d(Module):\n",
    "    '''Global Pooling layer that performs a weighted average along the temporal axis\n",
    "    \n",
    "    It can be considered as a channel-wise form of local temporal attention. Inspired by the paper: \n",
    "    Hyun, J., Seong, H., & Kim, E. (2019). Universal Pooling--A New Pooling Method for Convolutional Neural Networks. arXiv preprint arXiv:1907.11440.'''\n",
    "\n",
    "    def __init__(self, n_in, seq_len, mult=2, n_layers=2, ln=False, dropout=0.5, act=nn.ReLU(), zero_init=True):\n",
    "        layers = nn.ModuleList()\n",
    "        for i in range(n_layers):\n",
    "            inp_mult = mult if i > 0 else 1\n",
    "            out_mult = mult if i < n_layers -1 else 1\n",
    "            p = dropout[i] if is_listy(dropout) else dropout\n",
    "            layers.append(LinLnDrop(seq_len * inp_mult, seq_len * out_mult, ln=False, p=p, \n",
    "                                    act=act if i < n_layers-1 and n_layers > 1 else None))\n",
    "        self.layers = layers\n",
    "        self.softmax = SoftMax(-1)\n",
    "        if zero_init: init_lin_zero(self)\n",
    "\n",
    "    def forward(self, x):\n",
    "        wap = x\n",
    "        for l in self.layers: wap = l(wap)\n",
    "        wap = self.softmax(wap)\n",
    "        return torch.mul(x, wap).sum(-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "class GAP1d(Module):\n",
    "    \"Global Adaptive Pooling + Flatten\"\n",
    "    def __init__(self, output_size=1):\n",
    "        self.gap = nn.AdaptiveAvgPool1d(output_size)\n",
    "        self.flatten = Flatten()\n",
    "    def forward(self, x):\n",
    "        return self.flatten(self.gap(x))\n",
    "    \n",
    "    \n",
    "class GACP1d(Module):\n",
    "    \"Global AdaptiveConcatPool + Flatten\"\n",
    "    def __init__(self, output_size=1):\n",
    "        self.gacp = AdaptiveConcatPool1d(output_size)\n",
    "        self.flatten = Flatten()\n",
    "    def forward(self, x):\n",
    "        return self.flatten(self.gacp(x))\n",
    "    \n",
    "\n",
    "class GAWP1d(Module):\n",
    "    \"Global AdaptiveWeightedAvgPool1d + Flatten\"\n",
    "    def __init__(self, n_in, seq_len, n_layers=2, ln=False, dropout=0.5, act=nn.ReLU(), zero_init=False):\n",
    "        self.gacp = AdaptiveWeightedAvgPool1d(n_in, seq_len, n_layers=n_layers, ln=ln, dropout=dropout, act=act, zero_init=zero_init)\n",
    "        self.flatten = Flatten()\n",
    "    def forward(self, x):\n",
    "        return self.flatten(self.gacp(x))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "class GlobalWeightedAveragePool1d(Module):\n",
    "    \"\"\" Global Weighted Average Pooling layer \n",
    "    \n",
    "    Inspired by Building Efficient CNN Architecture for Offline Handwritten Chinese Character Recognition\n",
    "    https://arxiv.org/pdf/1804.01259.pdf\n",
    "    \"\"\"\n",
    "    \n",
    "    def __init__(self, n_in, seq_len): \n",
    "        self.weight = nn.Parameter(torch.ones(1, n_in, seq_len))\n",
    "        self.bias = nn.Parameter(torch.zeros(1, n_in, seq_len))\n",
    "\n",
    "    def forward(self, x):\n",
    "        α = F.softmax(torch.sigmoid(x * self.weight + self.bias), dim=-1)\n",
    "        return (x * α).sum(-1)\n",
    "\n",
    "GWAP1d = GlobalWeightedAveragePool1d\n",
    "\n",
    "def gwa_pool_head(n_in, c_out, seq_len, bn=True, fc_dropout=0.):\n",
    "    return nn.Sequential(GlobalWeightedAveragePool1d(n_in, seq_len), Flatten(), LinBnDrop(n_in, c_out, p=fc_dropout, bn=bn))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "t = torch.randn(16, 64, 50)\n",
    "head = gwa_pool_head(64, 5, 50)\n",
    "test_eq(head(t).shape, (16, 5))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "class AttentionalPool1d(Module):\n",
    "    \"\"\"Global Adaptive Pooling layer inspired by Attentional Pooling for Action Recognition https://arxiv.org/abs/1711.01467\"\"\"\n",
    "    def __init__(self, n_in, c_out, bn=False): \n",
    "        store_attr()\n",
    "        self.bn = nn.BatchNorm1d(n_in) if bn else None\n",
    "        self.conv1 = Conv1d(n_in, 1, 1)\n",
    "        self.conv2 = Conv1d(n_in, c_out, 1)\n",
    "\n",
    "    def forward(self, x):\n",
    "        if self.bn is not None: x = self.bn(x) \n",
    "        return (self.conv1(x) @ self.conv2(x).transpose(1,2)).transpose(1,2)\n",
    "    \n",
    "class GAttP1d(nn.Sequential):\n",
    "    def __init__(self, n_in, c_out, bn=False):\n",
    "        super().__init__(AttentionalPool1d(n_in, c_out, bn=bn), Flatten())\n",
    "        \n",
    "def attentional_pool_head(n_in, c_out, seq_len=None, bn=True, **kwargs):\n",
    "    return nn.Sequential(AttentionalPool1d(n_in, c_out, bn=bn, **kwargs), Flatten())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bs, c_in, seq_len = 16, 1, 50\n",
    "c_out = 3\n",
    "t = torch.rand(bs, c_in, seq_len)\n",
    "test_eq(GAP1d()(t).shape, (bs, c_in))\n",
    "test_eq(GACP1d()(t).shape, (bs, c_in*2))\n",
    "bs, c_in, seq_len = 16, 4, 50\n",
    "t = torch.rand(bs, c_in, seq_len)\n",
    "test_eq(GAP1d()(t).shape, (bs, c_in))\n",
    "test_eq(GACP1d()(t).shape, (bs, c_in*2))\n",
    "test_eq(GAWP1d(c_in, seq_len, n_layers=2, ln=False, dropout=0.5, act=nn.ReLU(), zero_init=False)(t).shape, (bs, c_in))\n",
    "test_eq(GAWP1d(c_in, seq_len, n_layers=2, ln=False, dropout=0.5, act=nn.ReLU(), zero_init=False)(t).shape, (bs, c_in))\n",
    "test_eq(GAWP1d(c_in, seq_len, n_layers=1, ln=False, dropout=0.5, zero_init=False)(t).shape, (bs, c_in))\n",
    "test_eq(GAWP1d(c_in, seq_len, n_layers=1, ln=False, dropout=0.5, zero_init=True)(t).shape, (bs, c_in))\n",
    "test_eq(AttentionalPool1d(c_in, c_out)(t).shape, (bs, c_out, 1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bs, c_in, seq_len = 16, 128, 50\n",
    "c_out = 14\n",
    "t = torch.rand(bs, c_in, seq_len)\n",
    "attp = attentional_pool_head(c_in, c_out)\n",
    "test_eq(attp(t).shape, (bs, c_out))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "class GEGLU(Module):\n",
    "    def forward(self, x):\n",
    "        x, gates = x.chunk(2, dim=-1)\n",
    "        return x * F.gelu(gates)\n",
    "\n",
    "class ReGLU(Module):\n",
    "    def forward(self, x):\n",
    "        x, gates = x.chunk(2, dim=-1)\n",
    "        return x * F.relu(gates)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "pytorch_acts = [nn.ELU, nn.LeakyReLU, nn.PReLU, nn.ReLU, nn.ReLU6, nn.SELU, nn.CELU, nn.GELU, nn.Sigmoid, Mish, nn.Softplus,\n",
    "nn.Tanh, nn.Softmax, GEGLU, ReGLU]\n",
    "pytorch_act_names = [a.__name__.lower() for a in pytorch_acts]\n",
    "\n",
    "def get_act_fn(act, **act_kwargs):\n",
    "    if act is None: return\n",
    "    elif isinstance(act, nn.Module): return act\n",
    "    elif callable(act): return act(**act_kwargs)\n",
    "    idx = pytorch_act_names.index(act.lower())\n",
    "    return pytorch_acts[idx](**act_kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_eq(get_act_fn(nn.ReLU).__repr__(), \"ReLU()\")\n",
    "test_eq(get_act_fn(nn.ReLU()).__repr__(), \"ReLU()\")\n",
    "test_eq(get_act_fn(nn.LeakyReLU, negative_slope=0.05).__repr__(), \"LeakyReLU(negative_slope=0.05)\")\n",
    "test_eq(get_act_fn('reglu').__repr__(), \"ReGLU()\")\n",
    "test_eq(get_act_fn('leakyrelu', negative_slope=0.05).__repr__(), \"LeakyReLU(negative_slope=0.05)\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def create_pool_head(n_in, c_out, seq_len=None, concat_pool=False, fc_dropout=0., bn=False, y_range=None, **kwargs):\n",
    "    if kwargs: print(f'{kwargs}  not being used')\n",
    "    if concat_pool: n_in*=2\n",
    "    layers = [GACP1d(1) if concat_pool else GAP1d(1)]\n",
    "    layers += [LinBnDrop(n_in, c_out, bn=bn, p=fc_dropout)]\n",
    "    if y_range: layers += [SigmoidRange(*y_range)]\n",
    "    return nn.Sequential(*layers)\n",
    "\n",
    "pool_head = create_pool_head\n",
    "average_pool_head = partial(pool_head, concat_pool=False)\n",
    "setattr(average_pool_head, \"__name__\", \"average_pool_head\")\n",
    "concat_pool_head = partial(pool_head, concat_pool=True)\n",
    "setattr(concat_pool_head, \"__name__\", \"concat_pool_head\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Sequential(\n",
       "  (0): GACP1d(\n",
       "    (gacp): AdaptiveConcatPool1d(\n",
       "      (ap): AdaptiveAvgPool1d(output_size=1)\n",
       "      (mp): AdaptiveMaxPool1d(output_size=1)\n",
       "    )\n",
       "    (flatten): Flatten(full=False)\n",
       "  )\n",
       "  (1): LinBnDrop(\n",
       "    (0): BatchNorm1d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (1): Dropout(p=0.5, inplace=False)\n",
       "    (2): Linear(in_features=24, out_features=2, bias=False)\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "bs = 16\n",
    "nf = 12\n",
    "c_out = 2\n",
    "seq_len = 20\n",
    "t = torch.rand(bs, nf, seq_len)\n",
    "test_eq(create_pool_head(nf, c_out, seq_len, fc_dropout=0.5)(t).shape, (bs, c_out))\n",
    "test_eq(create_pool_head(nf, c_out, seq_len, concat_pool=True, fc_dropout=0.5)(t).shape, (bs, c_out))\n",
    "create_pool_head(nf, c_out, seq_len, concat_pool=True, bn=True, fc_dropout=.5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def max_pool_head(n_in, c_out, seq_len, fc_dropout=0., bn=False, y_range=None, **kwargs):\n",
    "    if kwargs: print(f'{kwargs}  not being used')\n",
    "    layers = [nn.MaxPool1d(seq_len, **kwargs), Flatten()]\n",
    "    layers += [LinBnDrop(n_in, c_out, bn=bn, p=fc_dropout)]\n",
    "    if y_range: layers += [SigmoidRange(*y_range)]\n",
    "    return nn.Sequential(*layers)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bs = 16\n",
    "nf = 12\n",
    "c_out = 2\n",
    "seq_len = 20\n",
    "t = torch.rand(bs, nf, seq_len)\n",
    "test_eq(max_pool_head(nf, c_out, seq_len, fc_dropout=0.5)(t).shape, (bs, c_out))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def create_pool_plus_head(*args, lin_ftrs=None, fc_dropout=0., concat_pool=True, bn_final=False, lin_first=False, y_range=None):\n",
    "    nf = args[0]\n",
    "    c_out = args[1]\n",
    "    if concat_pool: nf = nf * 2\n",
    "    lin_ftrs = [nf, 512, c_out] if lin_ftrs is None else [nf] + lin_ftrs + [c_out]\n",
    "    ps = L(fc_dropout)\n",
    "    if len(ps) == 1: ps = [ps[0]/2] * (len(lin_ftrs)-2) + ps\n",
    "    actns = [nn.ReLU(inplace=True)] * (len(lin_ftrs)-2) + [None]\n",
    "    pool = AdaptiveConcatPool1d() if concat_pool else nn.AdaptiveAvgPool1d(1)\n",
    "    layers = [pool, Flatten()]\n",
    "    if lin_first: layers.append(nn.Dropout(ps.pop(0)))\n",
    "    for ni,no,p,actn in zip(lin_ftrs[:-1], lin_ftrs[1:], ps, actns):\n",
    "        layers += LinBnDrop(ni, no, bn=True, p=p, act=actn, lin_first=lin_first)\n",
    "    if lin_first: layers.append(nn.Linear(lin_ftrs[-2], c_out))\n",
    "    if bn_final: layers.append(nn.BatchNorm1d(lin_ftrs[-1], momentum=0.01))\n",
    "    if y_range is not None: layers.append(SigmoidRange(*y_range))\n",
    "    return nn.Sequential(*layers)\n",
    "\n",
    "pool_plus_head = create_pool_plus_head"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Sequential(\n",
       "  (0): AdaptiveConcatPool1d(\n",
       "    (ap): AdaptiveAvgPool1d(output_size=1)\n",
       "    (mp): AdaptiveMaxPool1d(output_size=1)\n",
       "  )\n",
       "  (1): Flatten(full=False)\n",
       "  (2): BatchNorm1d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "  (3): Dropout(p=0.25, inplace=False)\n",
       "  (4): Linear(in_features=24, out_features=512, bias=False)\n",
       "  (5): ReLU(inplace=True)\n",
       "  (6): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "  (7): Dropout(p=0.5, inplace=False)\n",
       "  (8): Linear(in_features=512, out_features=2, bias=False)\n",
       ")"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "bs = 16\n",
    "nf = 12\n",
    "c_out = 2\n",
    "seq_len = 20\n",
    "t = torch.rand(bs, nf, seq_len)\n",
    "test_eq(create_pool_plus_head(nf, c_out, seq_len, fc_dropout=0.5)(t).shape, (bs, c_out))\n",
    "test_eq(create_pool_plus_head(nf, c_out, concat_pool=True, fc_dropout=0.5)(t).shape, (bs, c_out))\n",
    "create_pool_plus_head(nf, c_out, seq_len, fc_dropout=0.5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def create_conv_head(*args, adaptive_size=None, y_range=None):\n",
    "    nf = args[0]\n",
    "    c_out = args[1]\n",
    "    layers = [nn.AdaptiveAvgPool1d(adaptive_size)] if adaptive_size is not None else []\n",
    "    for i in range(2):\n",
    "        if nf > 1: \n",
    "            layers += [ConvBlock(nf, nf // 2, 1)] \n",
    "            nf = nf//2\n",
    "        else: break\n",
    "    layers += [ConvBlock(nf, c_out, 1), GAP1d(1)]\n",
    "    if y_range: layers += [SigmoidRange(*y_range)]\n",
    "    return nn.Sequential(*layers)\n",
    "\n",
    "conv_head = create_conv_head"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Sequential(\n",
       "  (0): ConvBlock(\n",
       "    (0): Conv1d(12, 6, kernel_size=(1,), stride=(1,), bias=False)\n",
       "    (1): BatchNorm1d(6, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (2): ReLU()\n",
       "  )\n",
       "  (1): ConvBlock(\n",
       "    (0): Conv1d(6, 3, kernel_size=(1,), stride=(1,), bias=False)\n",
       "    (1): BatchNorm1d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (2): ReLU()\n",
       "  )\n",
       "  (2): ConvBlock(\n",
       "    (0): Conv1d(3, 2, kernel_size=(1,), stride=(1,), bias=False)\n",
       "    (1): BatchNorm1d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (2): ReLU()\n",
       "  )\n",
       "  (3): GAP1d(\n",
       "    (gap): AdaptiveAvgPool1d(output_size=1)\n",
       "    (flatten): Flatten(full=False)\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "bs = 16\n",
    "nf = 12\n",
    "c_out = 2\n",
    "seq_len = 20\n",
    "t = torch.rand(bs, nf, seq_len)\n",
    "test_eq(create_conv_head(nf, c_out, seq_len)(t).shape, (bs, c_out))\n",
    "test_eq(create_conv_head(nf, c_out, adaptive_size=50)(t).shape, (bs, c_out))\n",
    "create_conv_head(nf, c_out, 50)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def create_mlp_head(nf, c_out, seq_len=None, flatten=True, fc_dropout=0., bn=False, lin_first=False, y_range=None):\n",
    "    if flatten: nf *= seq_len\n",
    "    layers = [Flatten()] if flatten else []\n",
    "    layers += [LinBnDrop(nf, c_out, bn=bn, p=fc_dropout, lin_first=lin_first)]\n",
    "    if y_range: layers += [SigmoidRange(*y_range)]\n",
    "    return nn.Sequential(*layers)\n",
    "\n",
    "mlp_head = create_mlp_head"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Sequential(\n",
       "  (0): Flatten(full=False)\n",
       "  (1): LinBnDrop(\n",
       "    (0): BatchNorm1d(240, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (1): Dropout(p=0.5, inplace=False)\n",
       "    (2): Linear(in_features=240, out_features=2, bias=False)\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "bs = 16\n",
    "nf = 12\n",
    "c_out = 2\n",
    "seq_len = 20\n",
    "t = torch.rand(bs, nf, seq_len)\n",
    "test_eq(create_mlp_head(nf, c_out, seq_len, fc_dropout=0.5)(t).shape, (bs, c_out))\n",
    "t = torch.rand(bs, nf, seq_len)\n",
    "create_mlp_head(nf, c_out, seq_len, bn=True, fc_dropout=.5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def create_fc_head(nf, c_out, seq_len=None, flatten=True, lin_ftrs=None, y_range=None, fc_dropout=0., bn=False, bn_final=False, act=nn.ReLU(inplace=True)):\n",
    "    if flatten: nf *= seq_len\n",
    "    layers = [Flatten()] if flatten else []\n",
    "    lin_ftrs = [nf, 512, c_out] if lin_ftrs is None else [nf] + lin_ftrs + [c_out]\n",
    "    if not is_listy(fc_dropout): fc_dropout = [fc_dropout]*(len(lin_ftrs) - 1)\n",
    "    actns = [act for _ in range(len(lin_ftrs) - 2)] + [None]\n",
    "    layers += [LinBnDrop(lin_ftrs[i], lin_ftrs[i+1], bn=bn and (i!=len(actns)-1 or bn_final), p=p, act=a) for i,(p,a) in enumerate(zip(fc_dropout+[0.], actns))]\n",
    "    if y_range is not None: layers.append(SigmoidRange(*y_range))\n",
    "    return nn.Sequential(*layers)\n",
    "\n",
    "fc_head = create_fc_head"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Sequential(\n",
       "  (0): Flatten(full=False)\n",
       "  (1): LinBnDrop(\n",
       "    (0): BatchNorm1d(240, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (1): Dropout(p=0.5, inplace=False)\n",
       "    (2): Linear(in_features=240, out_features=2, bias=False)\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "bs = 16\n",
    "nf = 12\n",
    "c_out = 2\n",
    "seq_len = 20\n",
    "t = torch.rand(bs, nf, seq_len)\n",
    "test_eq(create_fc_head(nf, c_out, seq_len, fc_dropout=0.5)(t).shape, (bs, c_out))\n",
    "create_mlp_head(nf, c_out, seq_len, bn=True, fc_dropout=.5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def create_rnn_head(*args, fc_dropout=0., bn=False, y_range=None):\n",
    "    nf = args[0]\n",
    "    c_out = args[1]\n",
    "    layers = [LastStep()]\n",
    "    layers += [LinBnDrop(nf, c_out, bn=bn, p=fc_dropout)]\n",
    "    if y_range: layers += [SigmoidRange(*y_range)]\n",
    "    return nn.Sequential(*layers)\n",
    "\n",
    "rnn_head = create_rnn_head"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Sequential(\n",
       "  (0): LastStep()\n",
       "  (1): LinBnDrop(\n",
       "    (0): BatchNorm1d(12, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (1): Dropout(p=0.5, inplace=False)\n",
       "    (2): Linear(in_features=12, out_features=2, bias=False)\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "bs = 16\n",
    "nf = 12\n",
    "c_out = 2\n",
    "seq_len = 20\n",
    "t = torch.rand(bs, nf, seq_len)\n",
    "test_eq(create_rnn_head(nf, c_out, seq_len, fc_dropout=0.5)(t).shape, (bs, c_out))\n",
    "create_rnn_head(nf, c_out, seq_len, bn=True, fc_dropout=.5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "\n",
    "def imputation_head(c_in, c_out, seq_len=None, ks=1, y_range=None, fc_dropout=0.):\n",
    "    layers = [nn.Dropout(fc_dropout), nn.Conv1d(c_in, c_out, ks)]\n",
    "    if y_range is not None: \n",
    "        y_range = (tensor(y_range[0]), tensor(y_range[1]))\n",
    "        layers += [SigmoidRange(*y_range)]\n",
    "    return nn.Sequential(*layers)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Sequential(\n",
       "  (0): Dropout(p=0.0, inplace=False)\n",
       "  (1): Conv1d(12, 2, kernel_size=(1,), stride=(1,))\n",
       "  (2): SigmoidRange(low=tensor([0.1000, 0.1000, 0.1000, 0.1000, 0.2000, 0.2000, 0.2000, 0.2000, 0.3000,\n",
       "          0.3000, 0.3000, 0.3000]), high=tensor([0.6000, 0.6000, 0.6000, 0.6000, 0.7000, 0.7000, 0.7000, 0.7000, 0.8000,\n",
       "          0.8000, 0.8000, 0.8000]))\n",
       ")"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "bs = 16\n",
    "nf = 12\n",
    "ni = 2\n",
    "seq_len = 20\n",
    "t = torch.rand(bs, nf, seq_len)\n",
    "head = imputation_head(nf, ni, seq_len=None, ks=1, y_range=None, fc_dropout=0.)\n",
    "test_eq(head(t).shape, (bs, ni, seq_len))\n",
    "head = imputation_head(nf, ni, seq_len=None, ks=1, y_range=(.3,.7), fc_dropout=0.)\n",
    "test_ge(head(t).min(), .3)\n",
    "test_le(head(t).max(), .7)\n",
    "y_range = (tensor([0.1000, 0.1000, 0.1000, 0.1000, 0.2000, 0.2000, 0.2000, 0.2000, 0.3000,\n",
    "                   0.3000, 0.3000, 0.3000]),\n",
    "           tensor([0.6000, 0.6000, 0.6000, 0.6000, 0.7000, 0.7000, 0.7000, 0.7000, 0.8000,\n",
    "                   0.8000, 0.8000, 0.8000]))\n",
    "test_ge(head(t).min(), .1)\n",
    "test_le(head(t).max(), .9)\n",
    "head = imputation_head(nf, ni, seq_len=None, ks=1, y_range=y_range, fc_dropout=0.)\n",
    "head"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "class create_conv_lin_3d_head(nn.Sequential):\n",
    "    \"Module to create a 3d output head\"\n",
    "\n",
    "    def __init__(self, n_in, n_out, seq_len, d=(), conv_first=True, conv_bn=True, lin_first=False, lin_bn=True, act=None, fc_dropout=0., **kwargs):\n",
    "\n",
    "        assert len(d) == 2, \"you must pass a tuple of len == 2 to create a 3d output\"\n",
    "        conv = [BatchNorm(n_in, ndim=1)] if conv_bn else []\n",
    "        conv.append(Conv1d(n_in, d[0], 1, padding=0, bias=not conv_bn, **kwargs))\n",
    "\n",
    "        l = [Transpose(-1, -2), BatchNorm(n_out if lin_first else seq_len, ndim=1), Transpose(-1, -2)] if lin_bn else []\n",
    "        if fc_dropout != 0: l.append(nn.Dropout(fc_dropout))\n",
    "\n",
    "        lin = [nn.Linear(seq_len, d[1], bias=not lin_bn)]\n",
    "        if act is not None: lin.append(act)\n",
    "\n",
    "        lin_layers = lin+l if lin_first else l+lin\n",
    "        layers = conv + lin_layers if conv_first else lin_layers + conv\n",
    "\n",
    "        super().__init__(*layers)\n",
    "        \n",
    "conv_lin_3d_head = create_conv_lin_3d_head"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "create_conv_lin_3d_head(\n",
       "  (0): BatchNorm1d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "  (1): Conv1d(3, 2, kernel_size=(1,), stride=(1,), bias=False)\n",
       "  (2): Transpose(-1, -2)\n",
       "  (3): BatchNorm1d(50, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "  (4): Transpose(-1, -2)\n",
       "  (5): Linear(in_features=50, out_features=10, bias=False)\n",
       ")"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "t = torch.randn(16, 3, 50)\n",
    "head = conv_lin_3d_head(3, 20, 50, (4,5))\n",
    "test_eq(head(t).shape, (16, 4, 5))\n",
    "head = conv_lin_3d_head(3, 20, 50, (2, 10))\n",
    "test_eq(head(t).shape, (16, 2, 10))\n",
    "head"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "class create_lin_3d_head(nn.Sequential):\n",
    "    \"Module to create a 3d output head with linear layers\"\n",
    "\n",
    "    def __init__(self, n_in, n_out, seq_len, d=(), lin_first=False, bn=True, act=None, fc_dropout=0.):\n",
    "\n",
    "        assert len(d) == 2, \"you must pass a tuple of len == 2 to create a 3d output\"\n",
    "        layers = [Flatten()]\n",
    "        layers += LinBnDrop(n_in * seq_len, n_out, bn=bn, p=fc_dropout, act=act, lin_first=lin_first)\n",
    "        layers += [Reshape(*d)]\n",
    "\n",
    "        super().__init__(*layers)\n",
    "        \n",
    "lin_3d_head = create_lin_3d_head"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "create_lin_3d_head(\n",
       "  (0): Flatten(full=False)\n",
       "  (1): BatchNorm1d(3200, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "  (2): Linear(in_features=3200, out_features=5, bias=False)\n",
       "  (3): Reshape(bs, 5, 1)\n",
       ")"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "t = torch.randn(16, 64, 50)\n",
    "head = lin_3d_head(64, 10, 50, (5,2))\n",
    "test_eq(head(t).shape, (16, 5, 2))\n",
    "head = lin_3d_head(64, 5, 50, (5, 1))\n",
    "test_eq(head(t).shape, (16, 5, 1))\n",
    "head"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "class create_conv_3d_head(nn.Sequential):\n",
    "    \"Module to create a 3d output head with a convolutional layer\"\n",
    "    def __init__(self, n_in, c_out, seq_len, d=(), lin_first=False, bn=True, act=None, fc_dropout=0.):\n",
    "        assert len(d) == 2, \"you must pass a tuple of len == 2 to create a 3d output\"\n",
    "        assert d[1] == seq_len, 'You can only use this head when learn.dls.len == learn.dls.d'\n",
    "        super().__init__(Conv(n_in, d[0], 1))\n",
    "        \n",
    "conv_3d_head = create_conv_3d_head"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bs = 16\n",
    "c_out = 4\n",
    "seq_len = 50\n",
    "d = (2,50)\n",
    "nf = 128\n",
    "t = torch.rand(bs, nf, seq_len)\n",
    "test_eq(conv_3d_head(nf, c_out, seq_len, d)(t).shape, (bs, *d))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def universal_pool_head(n_in, c_out, seq_len, mult=2, pool_n_layers=2, pool_ln=True, pool_dropout=0.5, pool_act=nn.ReLU(),\n",
    "                        zero_init=True, bn=True, fc_dropout=0.):\n",
    "    return nn.Sequential(AdaptiveWeightedAvgPool1d(n_in, seq_len, n_layers=pool_n_layers, mult=mult, ln=pool_ln, dropout=pool_dropout, act=pool_act), \n",
    "                         Flatten(), LinBnDrop(n_in, c_out, p=fc_dropout, bn=bn))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bs, c_in, seq_len = 16, 128, 50\n",
    "c_out = 14\n",
    "t = torch.rand(bs, c_in, seq_len)\n",
    "uph = universal_pool_head(c_in, c_out, seq_len)\n",
    "test_eq(uph(t).shape, (bs, c_out))\n",
    "uph = universal_pool_head(c_in, c_out, seq_len, 2)\n",
    "test_eq(uph(t).shape, (bs, c_out))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "heads = [mlp_head, fc_head, average_pool_head, max_pool_head, concat_pool_head, pool_plus_head, conv_head, rnn_head, \n",
    "         conv_lin_3d_head, lin_3d_head, conv_3d_head, attentional_pool_head, universal_pool_head, gwa_pool_head]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "create_mlp_head\n",
      "create_fc_head\n",
      "average_pool_head\n",
      "max_pool_head\n",
      "concat_pool_head\n",
      "create_pool_plus_head\n",
      "create_conv_head\n",
      "create_rnn_head\n",
      "create_conv_lin_3d_head\n",
      "create_lin_3d_head\n",
      "create_conv_3d_head\n",
      "attentional_pool_head\n",
      "universal_pool_head\n",
      "gwa_pool_head\n"
     ]
    }
   ],
   "source": [
    "bs, c_in, seq_len = 16, 128, 50\n",
    "c_out = 14\n",
    "d = (7, 2)\n",
    "t = torch.rand(bs, c_in, seq_len)\n",
    "for head in heads: \n",
    "    print(head.__name__)\n",
    "    if head.__name__ == 'create_conv_3d_head': \n",
    "        test_eq(head(c_in, c_out, seq_len, (d[0], seq_len))(t).shape, (bs, *(d[0], seq_len)))\n",
    "    elif '3d' in head.__name__: \n",
    "        test_eq(head(c_in, c_out, seq_len, d)(t).shape, (bs, *d))\n",
    "    else: \n",
    "        test_eq(head(c_in, c_out, seq_len)(t).shape, (bs, c_out))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "class SqueezeExciteBlock(Module):\n",
    "    def __init__(self, ni, reduction=16):\n",
    "        self.avg_pool = GAP1d(1)\n",
    "        self.fc = nn.Sequential(nn.Linear(ni, ni // reduction, bias=False), nn.ReLU(),  nn.Linear(ni // reduction, ni, bias=False), nn.Sigmoid())\n",
    "\n",
    "    def forward(self, x):\n",
    "        y = self.avg_pool(x)\n",
    "        y = self.fc(y).unsqueeze(2)\n",
    "        return x * y.expand_as(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bs = 2\n",
    "ni = 32\n",
    "sl = 4\n",
    "t = torch.rand(bs, ni, sl)\n",
    "test_eq(SqueezeExciteBlock(ni)(t).shape, (bs, ni, sl))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "class GaussianNoise(Module):\n",
    "    \"\"\"Gaussian noise regularizer.\n",
    "\n",
    "    Args:\n",
    "        sigma (float, optional): relative standard deviation used to generate the\n",
    "            noise. Relative means that it will be multiplied by the magnitude of\n",
    "            the value your are adding the noise to. This means that sigma can be\n",
    "            the same regardless of the scale of the vector.\n",
    "        is_relative_detach (bool, optional): whether to detach the variable before\n",
    "            computing the scale of the noise. If `False` then the scale of the noise\n",
    "            won't be seen as a constant but something to optimize: this will bias the\n",
    "            network to generate vectors with smaller values.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, sigma=0.1, is_relative_detach=True):\n",
    "        self.sigma, self.is_relative_detach = sigma, is_relative_detach\n",
    "\n",
    "    def forward(self, x):\n",
    "        if self.training and self.sigma not in [0, None]:\n",
    "            scale = self.sigma * (x.detach() if self.is_relative_detach else x)\n",
    "            sampled_noise = torch.empty(x.size(), device=x.device).normal_() * scale\n",
    "            x = x + sampled_noise\n",
    "        return x "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "t = torch.ones(2,3,4)\n",
    "test_ne(GaussianNoise()(t), t)\n",
    "test_eq(GaussianNoise()(t).shape, t.shape)\n",
    "t = torch.ones(2,3)\n",
    "test_ne(GaussianNoise()(t), t)\n",
    "test_eq(GaussianNoise()(t).shape, t.shape)\n",
    "t = torch.ones(2)\n",
    "test_ne(GaussianNoise()(t), t)\n",
    "test_eq(GaussianNoise()(t).shape, t.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def gambler_loss(reward=2):\n",
    "    def _gambler_loss(model_output, targets):\n",
    "        outputs = torch.nn.functional.softmax(model_output, dim=1)\n",
    "        outputs, reservation = outputs[:, :-1], outputs[:, -1]\n",
    "        gain = torch.gather(outputs, dim=1, index=targets.unsqueeze(1)).squeeze()\n",
    "        doubling_rate = (gain + reservation / reward).log()\n",
    "        return - doubling_rate.mean()\n",
    "    return _gambler_loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(0.7326)"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model_output = torch.rand(16, 3)\n",
    "targets = torch.randint(0, 2, (16,))\n",
    "criterion = gambler_loss(2)\n",
    "criterion(model_output, targets)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def CrossEntropyLossOneHot(output, target, **kwargs):\n",
    "    if target.ndim == 2: _, target = target.max(dim=1)\n",
    "    return nn.CrossEntropyLoss(**kwargs)(output, target)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(0.7428)"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "output = torch.rand(16, 2)\n",
    "target = torch.randint(0, 2, (16,))\n",
    "CrossEntropyLossOneHot(output, target)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(0.7565, grad_fn=<NllLossBackward>)"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from tsai.data.transforms import OneHot\n",
    "output = nn.Parameter(torch.rand(16, 2))\n",
    "target = torch.randint(0, 2, (16,))\n",
    "one_hot_target = OneHot()(target)\n",
    "CrossEntropyLossOneHot(output, one_hot_target)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "def proba_certainty(output):\n",
    "    if output.sum(-1).mean().item() != 1: output = F.softmax(output, -1)\n",
    "    return (output.max(-1).values - 1. / output.shape[-1])/( 1 - 1. / output.shape[-1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0.6781, 0.4493, 0.5333, 0.6873, 0.6694, 0.9724, 0.7729, 0.2818, 0.3398,\n",
       "        0.8349, 0.8810, 0.4944, 0.5647, 0.9025, 0.7630, 0.9613],\n",
       "       grad_fn=<DivBackward0>)"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#hide\n",
    "target = random_shuffle(concat(torch.zeros(5), torch.ones(7), torch.ones(4) + 1)).long()\n",
    "output = nn.Parameter(5 * torch.rand((16, 3)) - 5 * torch.rand((16, 3)))\n",
    "proba_certainty(output)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "def CrossEntropyLossOneHotWithUncertainty():\n",
    "    def _CrossEntropyLossOneHotWithUncertainty(output, target, **kwargs):\n",
    "        return (proba_certainty(output) * CrossEntropyLossOneHot(output, target, reduction='none', **kwargs)).mean()\n",
    "    return _CrossEntropyLossOneHotWithUncertainty"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ttest_ind:            t = -1.5827  p = 0.118873\n",
      "ttest_ind_from_stats: t = -1.5827  p = 0.118873\n",
      "formula:              t = -1.5827  p = 0.118873\n",
      "formula:              t = -1.5827\n"
     ]
    }
   ],
   "source": [
    "#hide\n",
    "# https://stackoverflow.com/questions/22611446/perform-2-sample-t-test\n",
    "\n",
    "from __future__ import print_function\n",
    "\n",
    "import numpy as np\n",
    "from scipy.stats import ttest_ind, ttest_ind_from_stats\n",
    "from scipy.special import stdtr\n",
    "\n",
    "np.random.seed(1)\n",
    "\n",
    "# Create sample data.\n",
    "a = np.random.randn(40)\n",
    "b = 4*np.random.randn(50)\n",
    "\n",
    "# Use scipy.stats.ttest_ind.\n",
    "t, p = ttest_ind(a, b, equal_var=False)\n",
    "print(\"ttest_ind:            t = %g  p = %g\" % (t, p))\n",
    "\n",
    "# Compute the descriptive statistics of a and b.\n",
    "abar = a.mean()\n",
    "avar = a.var(ddof=1)\n",
    "na = a.size\n",
    "adof = na - 1\n",
    "\n",
    "bbar = b.mean()\n",
    "bvar = b.var(ddof=1)\n",
    "nb = b.size\n",
    "bdof = nb - 1\n",
    "\n",
    "# Use scipy.stats.ttest_ind_from_stats.\n",
    "t2, p2 = ttest_ind_from_stats(abar, np.sqrt(avar), na,\n",
    "                              bbar, np.sqrt(bvar), nb,\n",
    "                              equal_var=False)\n",
    "print(\"ttest_ind_from_stats: t = %g  p = %g\" % (t2, p2))\n",
    "\n",
    "# Use the formulas directly.\n",
    "tf = (abar - bbar) / np.sqrt(avar/na + bvar/nb)\n",
    "dof = (avar/na + bvar/nb)**2 / (avar**2/(na**2*adof) + bvar**2/(nb**2*bdof))\n",
    "pf = 2*stdtr(dof, -np.abs(tf))\n",
    "\n",
    "print(\"formula:              t = %g  p = %g\" % (tf, pf))\n",
    "\n",
    "a = tensor(a)\n",
    "b = tensor(b)\n",
    "tf = (a.mean() - b.mean()) / torch.sqrt(a.var()/a.size(0) + b.var()/b.size(0))\n",
    "print(\"formula:              t = %g\" % (tf))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(-1.5827)"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ttest_tensor(a, b)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def ttest_bin_loss(output, target):\n",
    "    output = nn.Softmax(dim=-1)(output[:, 1])\n",
    "    return ttest_tensor(output[target == 0], output[target == 1])\n",
    "\n",
    "def ttest_reg_loss(output, target):\n",
    "    return ttest_tensor(output[target <= 0], output[target > 0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for _ in range(100):\n",
    "    output = torch.rand(256, 2)\n",
    "    target = torch.randint(0, 2, (256,))\n",
    "    test_close(ttest_bin_loss(output, target).item(), \n",
    "               ttest_ind(nn.Softmax(dim=-1)(output[:, 1])[target == 0], nn.Softmax(dim=-1)(output[:, 1])[target == 1], equal_var=False)[0], eps=1e-3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "class CenterLoss(Module):\n",
    "    r\"\"\"\n",
    "    Code in Pytorch has been slightly modified from: https://github.com/KaiyangZhou/pytorch-center-loss/blob/master/center_loss.py\n",
    "    Based on paper: Wen et al. A Discriminative Feature Learning Approach for Deep Face Recognition. ECCV 2016.\n",
    "\n",
    "    Args:\n",
    "        c_out (int): number of classes.\n",
    "        logits_dim (int): dim 1 of the logits. By default same as c_out (for one hot encoded logits)\n",
    "        \n",
    "    \"\"\"\n",
    "    def __init__(self, c_out, logits_dim=None):\n",
    "        logits_dim = ifnone(logits_dim, c_out)\n",
    "        self.c_out, self.logits_dim = c_out, logits_dim\n",
    "        self.centers = nn.Parameter(torch.randn(c_out, logits_dim))\n",
    "        self.classes = nn.Parameter(torch.arange(c_out).long(), requires_grad=False)\n",
    "\n",
    "    def forward(self, x, labels):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            x: feature matrix with shape (batch_size, logits_dim).\n",
    "            labels: ground truth labels with shape (batch_size).\n",
    "        \"\"\"\n",
    "        bs = x.shape[0]\n",
    "        distmat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(bs, self.c_out) + \\\n",
    "                  torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.c_out, bs).T\n",
    "        distmat = torch.addmm(distmat, x, self.centers.T, beta=1, alpha=-2)\n",
    "\n",
    "        labels = labels.unsqueeze(1).expand(bs, self.c_out)\n",
    "        mask = labels.eq(self.classes.expand(bs, self.c_out))\n",
    "\n",
    "        dist = distmat * mask.float()\n",
    "        loss = dist.clamp(min=1e-12, max=1e+12).sum() / bs\n",
    "\n",
    "        return loss\n",
    "\n",
    "\n",
    "class CenterPlusLoss(Module):\n",
    "    \n",
    "    def __init__(self, loss, c_out, λ=1e-2, logits_dim=None):\n",
    "        self.loss, self.c_out, self.λ = loss, c_out, λ\n",
    "        self.centerloss = CenterLoss(c_out, logits_dim)\n",
    "        \n",
    "    def forward(self, x, labels):\n",
    "        return self.loss(x, labels) + self.λ * self.centerloss(x, labels)\n",
    "    def __repr__(self): return f\"CenterPlusLoss(loss={self.loss}, c_out={self.c_out}, λ={self.λ})\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor(8.8957, grad_fn=<DivBackward0>),\n",
       " TensorBase(2.3989, grad_fn=<AliasBackward>))"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "c_in = 10\n",
    "x = torch.rand(64, c_in).to(device=default_device())\n",
    "x = F.softmax(x, dim=1)\n",
    "label = x.max(dim=1).indices\n",
    "CenterLoss(c_in).to(x.device)(x, label), CenterPlusLoss(LabelSmoothingCrossEntropyFlat(), c_in).to(x.device)(x, label)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "CenterPlusLoss(loss=FlattenedLoss of LabelSmoothingCrossEntropy(), c_out=10, λ=0.01)"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "CenterPlusLoss(LabelSmoothingCrossEntropyFlat(), c_in)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "class FocalLoss(Module):\n",
    "\n",
    "    def __init__(self, gamma=0, eps=1e-7):\n",
    "        self.gamma, self.eps, self.ce = gamma, eps, CrossEntropyLossFlat()\n",
    "\n",
    "    def forward(self, input, target):\n",
    "        logp = self.ce(input, target)\n",
    "        p = torch.exp(-logp)\n",
    "        loss = (1 - p) ** self.gamma * logp\n",
    "        return loss.mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "TensorBase(0.7457)"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "c_in = 10\n",
    "x = torch.rand(64, c_in).to(device=default_device())\n",
    "x = F.softmax(x, dim=1)\n",
    "label = x.max(dim=1).indices\n",
    "FocalLoss(c_in).to(x.device)(x, label)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "class TweedieLoss(Module):\n",
    "    def __init__(self, p=1.5, eps=1e-10):\n",
    "        \"\"\"\n",
    "        Tweedie loss as calculated in LightGBM\n",
    "        Args:\n",
    "            p: tweedie variance power (1 < p < 2)\n",
    "            eps: small number to avoid log(zero).\n",
    "        \"\"\"\n",
    "        assert p > 1 and p < 2, \"make sure 1 < p < 2\"\n",
    "        self.p, self.eps = p, eps\n",
    "\n",
    "    def forward(self, inp, targ):\n",
    "        inp = inp.flatten()\n",
    "        targ = targ.flatten()\n",
    "        torch.clamp_min_(inp, self.eps)\n",
    "        a = targ * torch.exp((1 - self.p) * torch.log(inp)) / (1 - self.p)\n",
    "        b = torch.exp((2 - self.p) * torch.log(inp)) / (2 - self.p)\n",
    "        loss = -a + b\n",
    "        return loss.mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(3.2491)"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "c_in = 10\n",
    "output = torch.rand(64).to(device=default_device())\n",
    "target = torch.rand(64).to(device=default_device())\n",
    "TweedieLoss().to(output.device)(output, target)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "class PositionwiseFeedForward(nn.Sequential):\n",
    "    def __init__(self, dim, dropout=0., act='reglu', mlp_ratio=1):\n",
    "        act_mult = 2 if act.lower() in [\"geglu\", \"reglu\"] else 1\n",
    "        super().__init__(nn.Linear(dim, dim * mlp_ratio * act_mult),\n",
    "                         get_act_fn(act),\n",
    "                         nn.Dropout(dropout),\n",
    "                         nn.Linear(dim * mlp_ratio, dim),\n",
    "                         nn.Dropout(dropout))\n",
    "\n",
    "class TokenLayer(Module):\n",
    "    def __init__(self, token=True): self.token = token\n",
    "    def forward(self, x): return x[..., 0] if self.token is not None else x.mean(-1)\n",
    "    def __repr__(self): return f\"{self.__class__.__name__}()\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "t = torch.randn(2,3,10)\n",
    "m = PositionwiseFeedForward(10, dropout=0., act='reglu', mlp_ratio=1)\n",
    "test_eq(m(t).shape, t.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "class ScaledDotProductAttention(Module):\n",
    "    \"\"\"Scaled Dot-Product Attention module (Vaswani et al., 2017) with optional residual attention from previous layer (He et al, 2020)\"\"\"\n",
    "\n",
    "    def __init__(self, attn_dropout=0., res_attention=False):  \n",
    "        self.attn_dropout = nn.Dropout(attn_dropout)\n",
    "        self.res_attention = res_attention\n",
    "\n",
    "    def forward(self, q:Tensor, k:Tensor, v:Tensor, prev:Optional[Tensor]=None, key_padding_mask:Optional[Tensor]=None, attn_mask:Optional[Tensor]=None):\n",
    "        '''\n",
    "        Input shape:\n",
    "            q               : [bs x n_heads x max_q_len x d_k]\n",
    "            k               : [bs x n_heads x d_k x seq_len]\n",
    "            v               : [bs x n_heads x seq_len x d_v]\n",
    "            prev            : [bs x n_heads x q_len x seq_len]\n",
    "            key_padding_mask: [bs x seq_len]\n",
    "            attn_mask       : [1 x seq_len x seq_len]\n",
    "\n",
    "        Output shape: \n",
    "            output:  [bs x n_heads x q_len x d_v]\n",
    "            attn   : [bs x n_heads x q_len x seq_len]\n",
    "            scores : [bs x n_heads x q_len x seq_len]\n",
    "        '''\n",
    "\n",
    "        # Scaled MatMul (q, k) - similarity scores for all pairs of positions in an input sequence\n",
    "        attn_scores = torch.matmul(q / np.sqrt(q.shape[-2]), k)      # attn_scores : [bs x n_heads x max_q_len x q_len]\n",
    "\n",
    "        # Add pre-softmax attention scores from the previous layer (optional)\n",
    "        if prev is not None: attn_scores = attn_scores + prev \n",
    "\n",
    "        # Attention mask (optional)\n",
    "        if attn_mask is not None:                                     # attn_mask with shape [q_len x seq_len] - only used when q_len == seq_len\n",
    "            if attn_mask.dtype == torch.bool:\n",
    "                attn_scores.masked_fill_(attn_mask, -np.inf)\n",
    "            else:\n",
    "                attn_scores += attn_mask\n",
    "\n",
    "        # Key padding mask (optional)\n",
    "        if key_padding_mask is not None:                              # mask with shape [bs x q_len] (only when max_w_len == q_len)\n",
    "            attn_scores.masked_fill_(key_padding_mask.unsqueeze(1).unsqueeze(2), -np.inf)\n",
    "\n",
    "        # normalize the attention weights\n",
    "        attn_weights = F.softmax(attn_scores, dim=-1)                 # attn_weights   : [bs x n_heads x max_q_len x q_len]\n",
    "        attn_weights = self.attn_dropout(attn_weights)\n",
    "\n",
    "        # compute the new values given the attention weights\n",
    "        output = torch.matmul(attn_weights, v)                        # output: [bs x n_heads x max_q_len x d_v]\n",
    "\n",
    "        if self.res_attention: return output, attn_weights, attn_scores\n",
    "        else: return output, attn_weights"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor(-1.2480e-10, grad_fn=<MeanBackward0>),\n",
       " tensor(0.4757, grad_fn=<StdBackward>))"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "B = 16\n",
    "C = 10\n",
    "M = 1500 # seq_len\n",
    "\n",
    "n_heads = 1\n",
    "D = 128 # model dimension\n",
    "N = 512 # max_seq_len - latent's index dimension\n",
    "d_k = D // n_heads\n",
    "\n",
    "xb = torch.randn(B, C, M)\n",
    "xb = (xb - xb.mean()) / xb.std()\n",
    "\n",
    "# Attention\n",
    "# input (Q)\n",
    "lin = nn.Linear(M, N, bias=False)\n",
    "Q = lin(xb).transpose(1,2)\n",
    "test_eq(Q.shape, (B, N, C))\n",
    "\n",
    "# q\n",
    "to_q = nn.Linear(C, D, bias=False)\n",
    "q = to_q(Q)\n",
    "q = nn.LayerNorm(D)(q)\n",
    "\n",
    "# k, v\n",
    "context = xb.transpose(1,2)\n",
    "to_kv = nn.Linear(C, D * 2, bias=False)\n",
    "k, v = to_kv(context).chunk(2, dim = -1)\n",
    "k = k.transpose(-1, -2)\n",
    "k = nn.LayerNorm(M)(k)\n",
    "v = nn.LayerNorm(D)(v)\n",
    "\n",
    "test_eq(q.shape, (B, N, D))\n",
    "test_eq(k.shape, (B, D, M))\n",
    "test_eq(v.shape, (B, M, D))\n",
    "\n",
    "output, attn, scores = ScaledDotProductAttention(res_attention=True)(q.unsqueeze(1), k.unsqueeze(1), v.unsqueeze(1))\n",
    "test_eq(output.shape, (B, 1, N, D))\n",
    "test_eq(attn.shape, (B, 1, N, M))\n",
    "test_eq(scores.shape, (B, 1, N, M))\n",
    "scores.mean(), scores.std()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# #hide\n",
    "# class MultiheadAttention(Module):\n",
    "#     def __init__(self, d_model:int, n_heads:int, d_k:Optional[int]=None, d_v:Optional[int]=None, res_attention:bool=False, \n",
    "#                  dropout:float=0., qkv_bias:bool=True):\n",
    "#         \"\"\"Multi Head Attention Layer\n",
    "\n",
    "#         Input shape:\n",
    "#             Q:       [batch_size (bs) x max_q_len x d_model]\n",
    "#             K, V:    [batch_size (bs) x q_len x d_model]\n",
    "#             mask:    [q_len x q_len]\n",
    "#         \"\"\"\n",
    "\n",
    "#         d_k = ifnone(d_k, d_model // n_heads)\n",
    "#         d_v = ifnone(d_v, d_model // n_heads)\n",
    "\n",
    "#         self.n_heads, self.d_k, self.d_v = n_heads, d_k, d_v\n",
    "\n",
    "#         self.W_Q = nn.Linear(d_model, d_k * n_heads, bias=qkv_bias)\n",
    "#         self.W_K = nn.Linear(d_model, d_k * n_heads, bias=qkv_bias)\n",
    "#         self.W_V = nn.Linear(d_model, d_v * n_heads, bias=qkv_bias)\n",
    "\n",
    "#         # Scaled Dot-Product Attention (multiple heads)\n",
    "#         self.res_attention = res_attention\n",
    "#         self.sdp_attn = ScaledDotProductAttention(res_attention=self.res_attention)\n",
    "\n",
    "#         # Poject output\n",
    "#         project_out = not (n_heads == 1 and d_model == d_k)\n",
    "#         self.to_out = nn.Sequential(nn.Linear(n_heads * d_v, d_model), nn.Dropout(dropout)) if project_out else nn.Identity()\n",
    "\n",
    "\n",
    "#     def forward(self, Q:Tensor, K:Optional[Tensor]=None, V:Optional[Tensor]=None, prev:Optional[Tensor]=None,\n",
    "#                 key_padding_mask:Optional[Tensor]=None, attn_mask:Optional[Tensor]=None):\n",
    "\n",
    "#         bs = Q.size(0)\n",
    "#         if K is None: K = Q\n",
    "#         if V is None: V = Q\n",
    "\n",
    "#         # Linear (+ split in multiple heads)\n",
    "#         q_s = self.W_Q(Q).view(bs, -1, self.n_heads, self.d_k).transpose(1,2)       # q_s    : [bs x n_heads x max_q_len x d_k]\n",
    "#         k_s = self.W_K(K).view(bs, -1, self.n_heads, self.d_k).permute(0,2,3,1)     # k_s    : [bs x n_heads x d_k x q_len] - transpose(1,2) + transpose(2,3)\n",
    "#         v_s = self.W_V(V).view(bs, -1, self.n_heads, self.d_v).transpose(1,2)       # v_s    : [bs x n_heads x q_len x d_v]\n",
    "\n",
    "#         # Apply Scaled Dot-Product Attention (multiple heads)\n",
    "#         if self.res_attention:\n",
    "#             output, attn_weights, attn_scores = self.sdp_attn(q_s, k_s, v_s, prev=prev, key_padding_mask=key_padding_mask, attn_mask=attn_mask)\n",
    "#         else:\n",
    "#             output, attn_weights = self.sdp_attn(q_s, k_s, v_s, key_padding_mask=key_padding_mask, attn_mask=attn_mask)\n",
    "#         # output: [bs x n_heads x q_len x d_v], attn: [bs x n_heads x q_len x q_len], scores: [bs x n_heads x max_q_len x q_len]\n",
    "\n",
    "#         # back to the original inputs dimensions\n",
    "#         output = output.transpose(1, 2).contiguous().view(bs, -1, self.n_heads * self.d_v) # output: [bs x q_len x n_heads * d_v]\n",
    "#         output = self.to_out(output)\n",
    "\n",
    "#         if self.res_attention: return output, attn_weights, attn_scores\n",
    "#         else: return output, attn_weights "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "class MultiheadAttention(Module):\n",
    "    def __init__(self, d_model, n_heads, d_k=None, d_v=None, res_attention=False, attn_dropout=0., proj_dropout=0., qkv_bias=True):\n",
    "        \"\"\"Multi Head Attention Layer\n",
    "\n",
    "        Input shape:\n",
    "            Q:       [batch_size (bs) x max_q_len x d_model]\n",
    "            K, V:    [batch_size (bs) x q_len x d_model]\n",
    "            mask:    [q_len x q_len]\n",
    "        \"\"\"\n",
    "\n",
    "        d_k = ifnone(d_k, d_model // n_heads)\n",
    "        d_v = ifnone(d_v, d_model // n_heads)\n",
    "\n",
    "        self.n_heads, self.d_k, self.d_v = n_heads, d_k, d_v\n",
    "\n",
    "        self.W_Q = nn.Linear(d_model, d_k * n_heads, bias=qkv_bias)\n",
    "        self.W_K = nn.Linear(d_model, d_k * n_heads, bias=qkv_bias)\n",
    "        self.W_V = nn.Linear(d_model, d_v * n_heads, bias=qkv_bias)\n",
    "\n",
    "        # Scaled Dot-Product Attention (multiple heads)\n",
    "        self.res_attention = res_attention\n",
    "        self.sdp_attn = ScaledDotProductAttention(attn_dropout=attn_dropout, res_attention=self.res_attention)\n",
    "\n",
    "        # Poject output\n",
    "        self.to_out = nn.Sequential(nn.Linear(n_heads * d_v, d_model), nn.Dropout(proj_dropout))\n",
    "\n",
    "\n",
    "    def forward(self, Q:Tensor, K:Optional[Tensor]=None, V:Optional[Tensor]=None, prev:Optional[Tensor]=None,\n",
    "                key_padding_mask:Optional[Tensor]=None, attn_mask:Optional[Tensor]=None):\n",
    "\n",
    "        bs = Q.size(0)\n",
    "        if K is None: K = Q\n",
    "        if V is None: V = Q\n",
    "\n",
    "        # Linear (+ split in multiple heads)\n",
    "        q_s = self.W_Q(Q).view(bs, -1, self.n_heads, self.d_k).transpose(1,2)       # q_s    : [bs x n_heads x max_q_len x d_k]\n",
    "        k_s = self.W_K(K).view(bs, -1, self.n_heads, self.d_k).permute(0,2,3,1)     # k_s    : [bs x n_heads x d_k x q_len] - transpose(1,2) + transpose(2,3)\n",
    "        v_s = self.W_V(V).view(bs, -1, self.n_heads, self.d_v).transpose(1,2)       # v_s    : [bs x n_heads x q_len x d_v]\n",
    "\n",
    "        # Apply Scaled Dot-Product Attention (multiple heads)\n",
    "        if self.res_attention:\n",
    "            output, attn_weights, attn_scores = self.sdp_attn(q_s, k_s, v_s, prev=prev, key_padding_mask=key_padding_mask, attn_mask=attn_mask)\n",
    "        else:\n",
    "            output, attn_weights = self.sdp_attn(q_s, k_s, v_s, key_padding_mask=key_padding_mask, attn_mask=attn_mask)\n",
    "        # output: [bs x n_heads x q_len x d_v], attn: [bs x n_heads x q_len x q_len], scores: [bs x n_heads x max_q_len x q_len]\n",
    "\n",
    "        # back to the original inputs dimensions\n",
    "        output = output.transpose(1, 2).contiguous().view(bs, -1, self.n_heads * self.d_v) # output: [bs x q_len x n_heads * d_v]\n",
    "        output = self.to_out(output)\n",
    "\n",
    "        if self.res_attention: return output, attn_weights, attn_scores\n",
    "        else: return output, attn_weights "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "attn_mask torch.Size([50, 50]) key_padding_mask torch.Size([16, 50])\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(torch.Size([16, 3, 50, 6]), torch.Size([16, 3, 50, 50]))"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "q = torch.rand([16, 3, 50, 8]) \n",
    "k = torch.rand([16, 3, 50, 8]).transpose(-1, -2)\n",
    "v = torch.rand([16, 3, 50, 6])\n",
    "attn_mask = torch.triu(torch.ones(50, 50)) # shape: q_len x q_len\n",
    "key_padding_mask = torch.zeros(16, 50)\n",
    "key_padding_mask[[1, 3, 6, 15], -10:] = 1\n",
    "key_padding_mask = key_padding_mask.bool()\n",
    "print('attn_mask', attn_mask.shape, 'key_padding_mask', key_padding_mask.shape)\n",
    "output, attn = ScaledDotProductAttention(attn_dropout=.1)(q, k, v, attn_mask=attn_mask, key_padding_mask=key_padding_mask)\n",
    "output.shape, attn.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.Size([16, 50, 128]), torch.Size([16, 3, 50, 50]))"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "t = torch.rand(16, 50, 128)\n",
    "output, attn = MultiheadAttention(d_model=128, n_heads=3, d_k=8, d_v=6)(t, t, t, key_padding_mask=key_padding_mask, attn_mask=attn_mask)\n",
    "output.shape, attn.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "t = torch.rand(16, 50, 128)\n",
    "att_mask = (torch.rand((50, 50)) > .85).float()\n",
    "att_mask[att_mask == 1] = -np.inf\n",
    "\n",
    "mha = MultiheadAttention(d_model=128, n_heads=3, d_k=8, d_v=6)\n",
    "output, attn = mha(t, t, t, attn_mask=att_mask)\n",
    "test_eq(torch.isnan(output).sum().item(), 0)\n",
    "test_eq(torch.isnan(attn).sum().item(), 0)\n",
    "loss = output[:2, :].sum()\n",
    "test_eq(torch.isnan(loss).sum().item(), 0)\n",
    "loss.backward()\n",
    "for n, p in mha.named_parameters(): test_eq(torch.isnan(p.grad).sum().item(), 0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "t = torch.rand(16, 50, 128)\n",
    "attn_mask = (torch.rand((50, 50)) > .85)\n",
    "\n",
    "# True values will be masked\n",
    "mha = MultiheadAttention(d_model=128, n_heads=3, d_k=8, d_v=6)\n",
    "output, attn = mha(t, t, t, attn_mask=att_mask)\n",
    "test_eq(torch.isnan(output).sum().item(), 0)\n",
    "test_eq(torch.isnan(attn).sum().item(), 0)\n",
    "loss = output[:2, :].sum()\n",
    "test_eq(torch.isnan(loss).sum().item(), 0)\n",
    "loss.backward()\n",
    "for n, p in mha.named_parameters(): test_eq(torch.isnan(p.grad).sum().item(), 0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "class MultiConv1d(Module):\n",
    "    \"\"\"Module that applies multiple convolutions with different kernel sizes\"\"\"\n",
    "\n",
    "    def __init__(self, ni, nf=None, kss=[1,3,5,7], keep_original=False, separable=False, dim=1, **kwargs):\n",
    "        kss = listify(kss)\n",
    "        n_layers = len(kss)\n",
    "        if ni == nf: keep_original = False\n",
    "        if nf is None: nf = ni * (keep_original + n_layers)\n",
    "        nfs = [(nf - ni*keep_original) // n_layers] * n_layers\n",
    "        while np.sum(nfs) + ni * keep_original < nf:\n",
    "            for i in range(len(nfs)):\n",
    "                nfs[i] += 1\n",
    "                if np.sum(nfs) + ni * keep_original == nf: break\n",
    "\n",
    "        _conv = SeparableConv1d if separable else Conv1d\n",
    "        self.layers = nn.ModuleList()\n",
    "        for nfi,ksi in zip(nfs, kss):\n",
    "            self.layers.append(_conv(ni, nfi, ksi, **kwargs))\n",
    "        self.keep_original, self.dim = keep_original, dim\n",
    "\n",
    "    def forward(self, x):\n",
    "        output = [x] if self.keep_original else []\n",
    "        for l in self.layers:\n",
    "            output.append(l(x))\n",
    "        x = torch.cat(output, dim=self.dim)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "t = torch.rand(16, 6, 37)\n",
    "test_eq(MultiConv1d(6, None, kss=[1,3,5], keep_original=True)(t).shape, [16, 24, 37])\n",
    "test_eq(MultiConv1d(6, 36, kss=[1,3,5], keep_original=False)(t).shape, [16, 36, 37])\n",
    "test_eq(MultiConv1d(6, None, kss=[1,3,5], keep_original=True, dim=-1)(t).shape, [16, 6, 37*4])\n",
    "test_eq(MultiConv1d(6, 60, kss=[1,3,5], keep_original=True)(t).shape, [16, 60, 37])\n",
    "test_eq(MultiConv1d(6, 60, kss=[1,3,5], separable=True)(t).shape, [16, 60, 37])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "class LSTMOutput(Module):\n",
    "    def forward(self, x): return x[0]\n",
    "    def __repr__(self): return f'{self.__class__.__name__}()'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "t = ([1], [2], [3])\n",
    "test_eq(LSTMOutput()(t), [1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "class MultiEmbedding(Module):\n",
    "    def __init__(self, c_in, n_embeds, embed_dims=None, cat_pos=None, std=0.01):\n",
    "        if embed_dims is None: \n",
    "            embed_dims = [emb_sz_rule(s) for s in n_embeds]\n",
    "        else:\n",
    "            embed_dims = listify(embed_dims)\n",
    "            if len(embed_dims) == 1: embed_dims = embed_dims * len(n_embeds)\n",
    "            assert len(embed_dims) == len(n_embeds)\n",
    "        cat_pos = torch.as_tensor(listify(cat_pos)) if cat_pos else torch.arange(len(n_embeds))\n",
    "        self.register_buffer(\"cat_pos\", cat_pos)\n",
    "        cont_pos = torch.tensor([p for p in torch.arange(c_in) if p not in self.cat_pos])\n",
    "        self.register_buffer(\"cont_pos\", cont_pos)\n",
    "        self.cat_embed = nn.ModuleList([Embedding(n,d,std=std) for n,d in zip(n_embeds, embed_dims)])\n",
    "\n",
    "    def forward(self, x):\n",
    "        if isinstance(x, tuple): x_cat, x_cont, *_ = x\n",
    "        else: x_cat, x_cont = x[:, self.cat_pos], x[:, self.cont_pos]\n",
    "        x_cat = torch.cat([e(x_cat[:,i].long()).transpose(1,2) for i,e in enumerate(self.cat_embed)],1)\n",
    "        return torch.cat([x_cat, x_cont], 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[3, 4] [3, 3] torch.Size([4, 3, 10]) torch.Size([4, 7, 10])\n"
     ]
    }
   ],
   "source": [
    "a = alphabet[np.random.randint(0,3,40)]\n",
    "b = ALPHABET[np.random.randint(6,10,40)]\n",
    "c = np.random.rand(40).reshape(4,1,10)\n",
    "map_a = {k:v for v,k in enumerate(np.unique(a))}\n",
    "map_b = {k:v for v,k in enumerate(np.unique(b))}\n",
    "n_embeds = [len(m.keys()) for m in [map_a, map_b]]\n",
    "szs = [emb_sz_rule(n) for n in n_embeds]\n",
    "a = np.asarray(a.map(map_a)).reshape(4,1,10)\n",
    "b = np.asarray(b.map(map_b)).reshape(4,1,10)\n",
    "inp = torch.from_numpy(np.concatenate((c,a,b), 1)).float()\n",
    "memb = MultiEmbedding(3, n_embeds, cat_pos=[1,2])\n",
    "# registered buffers are part of the state_dict() but not module.parameters()\n",
    "assert all([(k in memb.state_dict().keys()) for k in ['cat_pos', 'cont_pos']])\n",
    "embeddings = memb(inp)\n",
    "print(n_embeds, szs, inp.shape, embeddings.shape)\n",
    "test_eq(embeddings.shape, (inp.shape[0],sum(szs)+1,inp.shape[-1]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/javascript": [
       "IPython.notebook.save_checkpoint();"
      ],
      "text/plain": [
       "<IPython.core.display.Javascript object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "100_models.layers.ipynb saved at 2021-11-29 17:00:17.\n",
      "Converted 100_models.layers.ipynb.\n",
      "\n",
      "\n",
      "Correct conversion! 😃\n",
      "Total time elapsed 0.093 s\n",
      "Monday 29/11/21 17:00:20 CET\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "                <audio  controls=\"controls\" autoplay=\"autoplay\">\n",
       "                    <source src=\"data:audio/wav;base64,UklGRvQHAABXQVZFZm10IBAAAAABAAEAECcAACBOAAACABAAZGF0YdAHAAAAAPF/iPh/gOoOon6w6ayCoR2ZeyfbjobxK+F2Hs0XjKc5i3DGvzaTlEaraE+zz5uLUl9f46fHpWJdxVSrnfmw8mYEScqUP70cb0Q8X41uysJ1si6Eh1jYzXp9IE2DzOYsftYRyoCY9dJ/8QICgIcEun8D9PmAaBPlfT7lq4MFIlh61tYPiCswIHX+yBaOqT1QbuW7qpVQSv9lu6+xnvRVSlyopAypbGBTUdSalrSTaUBFYpInwUpxOzhti5TOdndyKhCGrdwAfBUcXIJB69p+Vw1egB76+n9q/h6ADglbf4LvnIHfF/981ODThF4m8HiS0riJVjQ6c+/EOZCYQfJrGrhBmPVNMmNArLKhQlkXWYqhbaxXY8ZNHphLuBJsZUEckCTFVHMgNKGJytIDeSUmw4QN4Qx9pReTgb3vYX/TCBuApf75f+P5Y4CRDdN+B+tngk8c8nt03CKGqipgd13OhotwOC5x9MCAknFFcmlmtPmagFFFYOCo0qRzXMhVi57pryNmIEqJlRi8bm52PfuNM8k4dfQv+4cO12l6zCGdg3jl730uE/KAPvS+f0wEAoAsA89/XfXQgBESIn6S5luDtiC8eh/YmIfpLqt1OMp5jXg8/24MveqUNUnPZsqw0Z3yVDldnaUOqIZfXlKrm36zzWhjRhaT+r+ncHI5/otUzfd2uSt7hl/bqXtoHaCC6+mqfrAOeoDD+PJ/xf8RgLMHfH/b8GeBihZIfSXidoQSJWB52NM1iRkzz3MkxpKPbUCrbDu5d5fgTAxkSK3JoEhYD1p2omere2LZTuqYLbdWa49Cx5Dww7tyXDUnioXRkHhwJyKFvd/AfPoYy4Fl7j1/LQorgEr9/X89+0qAOAwAf13sJoL8Gkd8wt25hWIp3Heez/eKODfPcSPCzpFNRDVqf7UlmnNQKGHgqd+jgVvJVm2f265QZTpLS5byur1tpT6ajvrHq3Q2MXWIxtUCehoj8YMk5LB9hRQegeTypn+nBQWA0QHgf7f2q4C5EFt+5ucOg2YfHXtq2SSHpS0ydnTL4IxFO6pvNb4ulBdInWfcsfSc7VMmXpSmE6eeXmZThJxpsgRohEfOk86+AHCoOpOMFsx1dv8s6oYT2k17uR7ngpXod34IEJqAaPfnfyABCIBZBpl/NPI2gTQVjX134x2ExSPMeR7VtYjZMWJ0W8ftjkA/YW1durCWykvjZFKu4p9LVwVbZKNkqpxh6U+6mRC2mGq2Q3SRvsIgcpc2sIpD0Bp4uiiFhW3ecXxOGgaCDe0Vf4cLPoDv+/5/mfw1gN4KKX+17emBqBmYfBHfVYUZKFR44NBtiv41bHJUwx+RJkP1apu2VJlkTwli4qrwoo1ax1dToNCtemRSTBGXz7kJbdM/PY/Dxht0dTLziH7Ul3loJEiE0uJsfdsVTYGL8Yt/AgcMgHYA7X8S+IqAYA+QfjzpxIIVHnp7tdqzhmAstXaxzEqMETpScGC/dJP3Rmdo8LIZnOVSEF+Opxumsl1sVF+dVrE5Z6NIiZSkvVdv2zsqjdnK8HVDLlyHyNjuegogM4NA5z9+YRG9gA722H97AgOA/gSyf43zCIHdE899yuTIg3ciNXpm1jmImTDwdJPITI4RPhRugbvslbFKt2Vfr/6eTFb4W1WkY6m6YPdQjJr2tNZp3EQlko7BgXHRNz2LAc+gdwMq7IUf3R58ohtFgrbr6n7hDFWAlPr8f/T9I4CECU9/De+vgVQY5nxh4POEzybJeCTS5YnCNAZzhsRzkP1Bsmu4t4aYU07nYuerA6KWWcJYO6HHrKJjaE3Zl624UWz/QOOPjcWHc7QzdIk40yl5tCWjhIDhJX0xF4CBMvBsf10IF4Ac//Z/bPlsgAcOwn6S6n6CwxzUewLcRoYaKzV38M23i9o493CNwL6S1UUuaQe0QpvbUfdfiqglpcRccFU+nkWwambASUiVfLyqbg49xY2eyWh1hy/Sh37XjHpaIYKD7OUEfrgS5IC09MV/1gMBgKMDyH/n9N6AhhINfh7mdoMoIZt6r9fAh1cvfHXNya6N4DzDbqi8K5WWSYlmbbAdnkpV6FxJpWSo1V8DUmGb3rMRaQBG2JJgwN9wCDnNi8HNI3dKK1aG0dvHe/UciIJf6rt+Og5wgDn59X9P/xWAKQhxf2XweYH+FjB9suGVhIMlOnlo02GJhTOdc7vFyo/TQGxs2Li7lz9NwmPurBihnVi7WSWiwKvGYntOpJiOt5drKUKMkFnE8HLxNPmJ9NG4eP8mAYUv4Np8hhi3gdruSX+3CSWAwP38f8f6UoCuDPF+6Os8gnAbKnxQ3d2F0imydzDPKIuiN5lxu8EKkrFE82kftW2az1DbYImpMqTUW3FWIJ83r5hl2koJlla7+m0+PmSOZcjcdMgwS4g11iZ6qCLUg5jkxn0QFA6BWvOvfzEFBIBHAtp/Qfa3gC4RSH5y5yeD2B/8evnYS4cULgR2CMsUja47cG/QvW6UeEhXZ3+xP51GVNVdP6Zpp+1eDFM5nMeySWghR4+TNL85cD46YIyCzKJ2kCzEhoTabXtGHs+CCemJfpMPjoDe9+t/qQALgM8Gj3++8UaBqRV2fQTjO4Q3JKd5r9TgiEYyMHTxxiWPpz8jbfq585YpTJpk960xoKFXsVoTo7yq6GGMTw==\" type=\"audio/wav\" />\n",
       "                    Your browser does not support the audio element.\n",
       "                </audio>\n",
       "              "
      ],
      "text/plain": [
       "<IPython.lib.display.Audio object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "#hide\n",
    "from tsai.imports import create_scripts\n",
    "from tsai.export import get_nb_name\n",
    "nb_name = get_nb_name()\n",
    "create_scripts(nb_name);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
